From 6d963fe176134eef8fb5715fbb681366128b94e9 Mon Sep 17 00:00:00 2001 From: beedi Date: Sun, 29 Jun 2025 15:52:07 +0000 Subject: [PATCH 01/18] feat:added encryptionManager class for and generate && recoverKey methods && test files --- requirements.txt | 3 +- src/lighthouseweb3/__init__.py | 19 +- .../functions/encryptionManager/__init__.py | 0 .../functions/encryptionManager/config.py | 2 + .../functions/encryptionManager/generate.py | 87 +++++++++ .../functions/encryptionManager/recoverKey.py | 171 ++++++++++++++++++ tests/tests_encryptionEngine/__init__.py | 0 tests/tests_encryptionEngine/test_generate.py | 79 ++++++++ .../test_recover_key.py | 146 +++++++++++++++ 9 files changed, 505 insertions(+), 2 deletions(-) create mode 100644 src/lighthouseweb3/functions/encryptionManager/__init__.py create mode 100644 src/lighthouseweb3/functions/encryptionManager/config.py create mode 100644 src/lighthouseweb3/functions/encryptionManager/generate.py create mode 100644 src/lighthouseweb3/functions/encryptionManager/recoverKey.py create mode 100644 tests/tests_encryptionEngine/__init__.py create mode 100644 tests/tests_encryptionEngine/test_generate.py create mode 100644 tests/tests_encryptionEngine/test_recover_key.py diff --git a/requirements.txt b/requirements.txt index 13c4766..97fd17c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ charset-normalizer==3.1.0 idna==3.4 requests==2.31.0 urllib3==2.0.2 -eth-account==0.13.7 \ No newline at end of file +eth-account==0.13.7 +cryptography \ No newline at end of file diff --git a/src/lighthouseweb3/__init__.py b/src/lighthouseweb3/__init__.py index b1d8d7c..14bca18 100644 --- a/src/lighthouseweb3/__init__.py +++ b/src/lighthouseweb3/__init__.py @@ -2,6 +2,7 @@ import os import io +from typing import List, Dict, Any from .functions import ( upload as d, deal_status, @@ -16,7 +17,7 @@ remove_ipns_record as removeIpnsRecord, create_wallet as createWallet ) - +from .functions.encryptionManager import generate, recoverKey class Lighthouse: def __init__(self, token: str = ""): @@ -224,3 +225,19 @@ def getTagged(self, tag: str): except Exception as e: raise e +class EncryptionManager: + @staticmethod + def generate(threshold: int, keyCount: int): + try: + return generate.generate(threshold, keyCount) + except Exception as e: + raise e + + + @staticmethod + def recoverKey(keyShards: List[Dict[str, Any]]): + try: + return recoverKey.recoverKey(keyShards) + except Exception as e: + raise e + \ No newline at end of file diff --git a/src/lighthouseweb3/functions/encryptionManager/__init__.py b/src/lighthouseweb3/functions/encryptionManager/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lighthouseweb3/functions/encryptionManager/config.py b/src/lighthouseweb3/functions/encryptionManager/config.py new file mode 100644 index 0000000..bcc79dc --- /dev/null +++ b/src/lighthouseweb3/functions/encryptionManager/config.py @@ -0,0 +1,2 @@ +#A 257-bit prime to accommodate 256-bit secrets +PRIME = 2**256 + 297 \ No newline at end of file diff --git a/src/lighthouseweb3/functions/encryptionManager/generate.py b/src/lighthouseweb3/functions/encryptionManager/generate.py new file mode 100644 index 0000000..c9e6ba4 --- /dev/null +++ b/src/lighthouseweb3/functions/encryptionManager/generate.py @@ -0,0 +1,87 @@ +import secrets +import logging +from typing import Dict, List, Any +from .config import PRIME +logger = logging.getLogger(__name__) + + +def evaluate_polynomial(coefficients: List[int], x: int, prime: int) -> int: + """ + Evaluate a polynomial with given coefficients at point x. + msk[0] is constant term (the secret), msk[1] is x coefficient, etc. + + Args: + coefficients: List of coefficients where coefficients[0] is the constant term + x: Point at which to evaluate the polynomial + prime: Prime number for the finite field + + Returns: + The result of the polynomial evaluation modulo prime + """ + result = 0 + x_power = 1 # x^0 = 1 + + for coefficient in coefficients: + result = (result + coefficient * x_power) % prime + x_power = (x_power * x) % prime + + return result + +async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: + """ + Generate threshold cryptography key shards using Shamir's Secret Sharing + + Args: + threshold: Minimum number of shards needed to reconstruct the secret + key_count: Total number of key shards to generate + + Returns: + { + "masterKey": "", + "keyShards": [ + { + "key": "", + "index": "" + } + ] + } + """ + logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}") + + msk=[] + idVec=[] + secVec=[] + + if threshold > key_count: + raise ValueError("key_count must be greater than or equal to threshold") + if threshold < 1 or key_count < 1: + raise ValueError("threshold and key_count must be positive integers") + + + msk = [secrets.randbits(256) for _ in range(threshold)] + master_key = msk[0] + + used_ids = set() + + for i in range(key_count): + while True: + id_vec = secrets.randbits(32) + if id_vec != 0 and id_vec not in used_ids: + idVec.append(id_vec) + used_ids.add(id_vec) + break + + for i in range(key_count): + y = evaluate_polynomial(msk, idVec[i], PRIME) + secVec.append(y) + + result = { + "masterKey": hex(master_key), + "keyShards": [{"key": hex(secVec[i]), "index": hex(idVec[i])} for i in range(key_count)] + } + return result + +if __name__ == "__main__": + import asyncio + result = asyncio.run(generate(threshold=1, key_count=1)) + print(result) \ No newline at end of file diff --git a/src/lighthouseweb3/functions/encryptionManager/recoverKey.py b/src/lighthouseweb3/functions/encryptionManager/recoverKey.py new file mode 100644 index 0000000..8856723 --- /dev/null +++ b/src/lighthouseweb3/functions/encryptionManager/recoverKey.py @@ -0,0 +1,171 @@ +from typing import List, Dict, Any +import logging +from .config import PRIME + +logger = logging.getLogger(__name__) + +from typing import Tuple + +def extended_gcd(a: int, b: int) -> Tuple[int, int, int]: + """Extended Euclidean algorithm to find modular inverse. + + Args: + a: First integer + b: Second integer + + Returns: + A tuple (g, x, y) such that a*x + b*y = g = gcd(a, b) + """ + if a == 0: + return b, 0, 1 + else: + g, y, x = extended_gcd(b % a, a) + return g, x - (b // a) * y, y + +def modinv(a: int, m: int) -> int: + """Find the modular inverse of a mod m.""" + g, x, y = extended_gcd(a, m) + if g != 1: + raise ValueError('Modular inverse does not exist') + else: + return x % m + +def lagrange_interpolation(shares: List[Dict[str, str]], prime: int) -> int: + """ + Reconstruct the secret using Lagrange interpolation. + + Args: + shares: List of dictionaries with 'key' and 'index' fields + prime: The prime number used in the finite field + + Returns: + The reconstructed secret as integer + + Raises: + ValueError: If there are duplicate indices + """ + + points = [] + seen_indices = set() + + for i, share in enumerate(shares): + try: + key_str, index_str = validate_share(share, i) + x = int(index_str, 16) + + if x in seen_indices: + raise ValueError(f"Duplicate share index found: 0x{x:x}") + seen_indices.add(x) + + y = int(key_str, 16) + points.append((x, y)) + except ValueError as e: + raise ValueError(f"Invalid share at position {i}: {e}") + + + secret = 0 + + for i, (x_i, y_i) in enumerate(points): + # Calculate the Lagrange basis polynomial L_i(0) + # Evaluate at x=0 to get the constant term + numerator = 1 + denominator = 1 + + for j, (x_j, _) in enumerate(points): + if i != j: + numerator = (numerator * (-x_j)) % prime + denominator = (denominator * (x_i - x_j)) % prime + + try: + inv_denominator = modinv(denominator, prime) + except ValueError as e: + raise ValueError(f"Error in modular inverse calculation: {e}") + + term = (y_i * numerator * inv_denominator) % prime + secret = (secret + term) % prime + + return secret + +def validate_share(share: Dict[str, str], index: int) -> Tuple[str, str]: + """Validate and normalize a single share. + + Args: + share: Dictionary containing 'key' and 'index' fields + index: Position of the share in the input list (for error messages) + + Returns: + Tuple of (normalized_key, normalized_index) as strings without '0x' prefix + + Raises: + ValueError: If the share is invalid + """ + if not isinstance(share, dict): + raise ValueError(f"Share at index {index} must be a dictionary") + + if 'key' not in share or 'index' not in share: + raise ValueError(f"Share at index {index} is missing required fields 'key' or 'index'") + + key_str = str(share['key']).strip().lower() + index_str = str(share['index']).strip().lower() + + if key_str.startswith('0x'): + key_str = key_str[2:] + if index_str.startswith('0x'): + index_str = index_str[2:] + + + if not key_str: + raise ValueError(f"Empty key in share at index {index}") + if not all(c in '0123456789abcdef' for c in key_str): + raise ValueError(f"Invalid key format in share at index {index}: must be a valid hex string") + + if len(key_str) % 2 != 0: + key_str = '0' + key_str + + if not index_str: + raise ValueError(f"Empty index in share at index {index}") + if not all(c in '0123456789abcdef' for c in index_str): + raise ValueError(f"Invalid index format in share at index {index}: must be a valid hex string") + + index_int = int(index_str, 16) + if not (0 <= index_int <= 0xFFFFFFFF): + raise ValueError(f"Index out of range in share at index {index}: must be between 0 and 2^32-1") + + return key_str, index_str + + +async def recoverKey(keyShards: List[Dict[str, str]]) -> Dict[str, Any]: + """ + Recover the master key from a subset of key shares using Lagrange interpolation. + + Args: + keyShards: List of dictionaries containing 'key' and 'index' fields + + Returns: + { + "masterKey": "", + "error": "" + } + """ + logger.info(f"Attempting to recover master key from {len(keyShards)} shares") + + try: + for i, share in enumerate(keyShards): + validate_share(share, i) + secret = lagrange_interpolation(keyShards, PRIME) + return { + "masterKey": hex(secret), + "error": None + } + except ValueError as e: + logger.error(f"Validation error during key recovery: {str(e)}") + return { + "masterKey": None, + "error": f"Validation error: {str(e)}" + } + except Exception as e: + logger.error(f"Error during key recovery: {str(e)}") + return { + "masterKey": None, + "error": f"Recovery error: {str(e)}" + } diff --git a/tests/tests_encryptionEngine/__init__.py b/tests/tests_encryptionEngine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tests_encryptionEngine/test_generate.py b/tests/tests_encryptionEngine/test_generate.py new file mode 100644 index 0000000..3408ed3 --- /dev/null +++ b/tests/tests_encryptionEngine/test_generate.py @@ -0,0 +1,79 @@ +import unittest +import asyncio +import logging +from src.lighthouseweb3 import EncryptionManager + +logger = logging.getLogger(__name__) + +class TestGenerate(unittest.TestCase): + """Test cases for the generate module.""" + + def test_generate_basic(self): + """Test basic key generation with default parameters.""" + async def run_test(): + result = await EncryptionManager.generate(threshold=2, keyCount=3) + + self.assertIn('masterKey', result) + self.assertIn('keyShards', result) + + # Check master key format (hex string with 0x prefix) + self.assertIsInstance(result['masterKey'], str) + self.assertTrue(result['masterKey'].startswith('0x')) + self.assertTrue(all(c in '0123456789abcdef' for c in result['masterKey'][2:])) + + # Check key shards + self.assertEqual(len(result['keyShards']), 3) + for shard in result['keyShards']: + self.assertIn('key', shard) + self.assertIn('index', shard) + + # Check key format (hex string with 0x prefix) + self.assertTrue(shard['key'].startswith('0x')) + self.assertTrue(all(c in '0123456789abcdef' for c in shard['key'][2:])) + + # Check index format (hex string with 0x prefix) + self.assertTrue(shard['index'].startswith('0x')) + self.assertTrue(all(c in '0123456789abcdef' for c in shard['index'][2:])) + + return result + + return asyncio.run(run_test()) + + def test_generate_custom_parameters(self): + """Test key generation with custom parameters.""" + async def run_test(): + threshold = 3 + key_count = 5 + + result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count) + + self.assertEqual(len(result['keyShards']), key_count) + + # Check all indices are present and unique + indices = [shard['index'] for shard in result['keyShards']] + self.assertEqual(len(set(indices)), key_count) # All unique + + # Verify all indices are valid hex strings with 0x prefix + for index in indices: + self.assertTrue(index.startswith('0x')) + self.assertTrue(all(c in '0123456789abcdef' for c in index[2:])) + + return result + + return asyncio.run(run_test()) + + def test_invalid_threshold(self): + """Test that invalid threshold raises an error.""" + async def run_test(): + with self.assertRaises(ValueError) as context: + await EncryptionManager.generate(threshold=0, keyCount=3) + self.assertIn("must be positive integers", str(context.exception)) + + with self.assertRaises(ValueError) as context: + await EncryptionManager.generate(threshold=4, keyCount=3) + self.assertIn("must be greater than or equal to threshold", str(context.exception)) + + return asyncio.run(run_test()) + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/tests/tests_encryptionEngine/test_recover_key.py b/tests/tests_encryptionEngine/test_recover_key.py new file mode 100644 index 0000000..e159687 --- /dev/null +++ b/tests/tests_encryptionEngine/test_recover_key.py @@ -0,0 +1,146 @@ +import unittest +import asyncio +import logging +from src.lighthouseweb3 import EncryptionManager + +logger = logging.getLogger(__name__) + +class TestRecoverKey(unittest.TestCase): + """Test cases for the recoverKey module.""" + + def test_empty_shares_list(self): + """Test that recovery fails with empty shares list.""" + async def run_test(): + result = await EncryptionManager.recoverKey([]) + self.assertEqual(result['masterKey'], '0x0') + self.assertIsNone(result['error']) + + return asyncio.run(run_test()) + + + def test_recover_key_with_generated_shares(self): + """Test key recovery with dynamically generated shares.""" + async def run_test(): + + threshold = 3 + key_count = 5 + gen_result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count) + master_key = gen_result['masterKey'] + + shares = gen_result['keyShards'][:threshold] + result = await EncryptionManager.recoverKey(shares) + self.assertEqual(result['masterKey'], master_key) + self.assertIsNone(result['error']) + + for i in range(key_count - threshold + 1): + subset = gen_result['keyShards'][i:i+threshold] + result = await EncryptionManager.recoverKey(subset) + self.assertEqual(result['masterKey'], master_key) + self.assertIsNone(result['error']) + + return result + + return asyncio.run(run_test()) + + def test_recover_key_insufficient_shares(self): + """Test with minimum threshold shares""" + async def run_test(): + threshold = 2 + key_count = 5 + gen_result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count) + master_key = gen_result['masterKey'] + shares = gen_result['keyShards'][:threshold] + result = await EncryptionManager.recoverKey(shares) + self.assertEqual(result['masterKey'], master_key) + self.assertIsNone(result['error']) + + result = await EncryptionManager.recoverKey(gen_result['keyShards']) + self.assertEqual(result['masterKey'], master_key) + self.assertIsNone(result['error']) + + return asyncio.run(run_test()) + + def test_insufficient_shares(self): + """Test with insufficient shares for recovery""" + async def run_test(): + threshold = 3 + key_count = 5 + gen_result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count) + + # Test with one less than threshold (should still work as long as we have at least 2 shares) + result = await EncryptionManager.recoverKey(gen_result['keyShards'][:threshold-1]) + self.assertIsNotNone(result['masterKey']) + self.assertIsNone(result['error']) + + # Test with single share (should still work as long as we have at least 1 share) + result = await EncryptionManager.recoverKey(gen_result['keyShards'][:1]) + self.assertIsNotNone(result['masterKey']) + self.assertIsNone(result['error']) + + return asyncio.run(run_test()) + + def test_various_threshold_combinations(self): + """Test recovery with various threshold and share count combinations""" + async def run_test(): + test_cases = [ + (2, 3), + (3, 5), + (4, 7), + (3, 10), + ] + + for threshold, total in test_cases: + with self.subTest(threshold=threshold, total=total): + gen_result = await EncryptionManager.generate( + threshold=threshold, + keyCount=total + ) + master_key = gen_result['masterKey'] + + shares = gen_result['keyShards'][:threshold] + result = await EncryptionManager.recoverKey(shares) + self.assertEqual(result['masterKey'], master_key) + self.assertIsNone(result['error']) + + result = await EncryptionManager.recoverKey(gen_result['keyShards']) + self.assertEqual(result['masterKey'], master_key) + self.assertIsNone(result['error']) + + import random + subset = random.sample(gen_result['keyShards'], threshold + 1) + result = await EncryptionManager.recoverKey(subset) + self.assertEqual(result['masterKey'], master_key) + self.assertIsNone(result['error']) + + return asyncio.run(run_test()) + + + def test_invalid_share_format(self): + """Test that invalid share formats are handled correctly.""" + async def run_test(): + result = await EncryptionManager.recoverKey(["not a dict", "another invalid"]) + self.assertIsNone(result['masterKey']) + self.assertIn("must be a dictionary", result['error']) + + result = await EncryptionManager.recoverKey([{'key': '123'}, {'key': '456'}]) + self.assertIsNone(result['masterKey']) + self.assertIn("missing required fields 'key' or 'index'", result['error'].lower()) + + result = await EncryptionManager.recoverKey([ + {'key': 'invalidhex', 'index': '1'}, + {'key': 'invalidhex2', 'index': '2'} + ]) + self.assertIsNone(result['masterKey']) + self.assertIn("invalid key format", result['error'].lower()) + + result = await EncryptionManager.recoverKey([ + {'key': 'a' * 63, 'index': 'invalidindex'}, + {'key': 'b' * 63, 'index': 'invalidindex2'} + ]) + self.assertIsNone(result['masterKey']) + self.assertIn("invalid index format", result['error'].lower()) + + return asyncio.run(run_test()) + +if __name__ == '__main__': + unittest.main(verbosity=2) From f5cea273e4e42706bf0988f6acc280b7c5f984da Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Sun, 29 Jun 2025 17:17:43 +0000 Subject: [PATCH 02/18] fix:removed func call from generate.py --- src/lighthouseweb3/functions/encryptionManager/generate.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/lighthouseweb3/functions/encryptionManager/generate.py b/src/lighthouseweb3/functions/encryptionManager/generate.py index c9e6ba4..ab8c2f9 100644 --- a/src/lighthouseweb3/functions/encryptionManager/generate.py +++ b/src/lighthouseweb3/functions/encryptionManager/generate.py @@ -80,8 +80,3 @@ async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: "keyShards": [{"key": hex(secVec[i]), "index": hex(idVec[i])} for i in range(key_count)] } return result - -if __name__ == "__main__": - import asyncio - result = asyncio.run(generate(threshold=1, key_count=1)) - print(result) \ No newline at end of file From 45c0e718b0018b47dec0785fbce3130f5b747820 Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Sun, 29 Jun 2025 18:59:22 +0000 Subject: [PATCH 03/18] fix:removed cryptography from requirements.txt --- requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 97fd17c..13c4766 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,4 @@ charset-normalizer==3.1.0 idna==3.4 requests==2.31.0 urllib3==2.0.2 -eth-account==0.13.7 -cryptography \ No newline at end of file +eth-account==0.13.7 \ No newline at end of file From 2b2382919930a68b4e4c9f1f21000158a5425327 Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Wed, 2 Jul 2025 18:43:59 +0000 Subject: [PATCH 04/18] fix: updated the variables casing --- src/lighthouseweb3/__init__.py | 7 +++++-- .../encryptionManager/{recoverKey.py => recover_key.py} | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) rename src/lighthouseweb3/functions/encryptionManager/{recoverKey.py => recover_key.py} (98%) diff --git a/src/lighthouseweb3/__init__.py b/src/lighthouseweb3/__init__.py index 14bca18..3a75113 100644 --- a/src/lighthouseweb3/__init__.py +++ b/src/lighthouseweb3/__init__.py @@ -17,7 +17,10 @@ remove_ipns_record as removeIpnsRecord, create_wallet as createWallet ) -from .functions.encryptionManager import generate, recoverKey +from .functions.encryptionManager import ( + generate, + recover_key as recoverKey +) class Lighthouse: def __init__(self, token: str = ""): @@ -237,7 +240,7 @@ def generate(threshold: int, keyCount: int): @staticmethod def recoverKey(keyShards: List[Dict[str, Any]]): try: - return recoverKey.recoverKey(keyShards) + return recoverKey.recover_key(keyShards) except Exception as e: raise e \ No newline at end of file diff --git a/src/lighthouseweb3/functions/encryptionManager/recoverKey.py b/src/lighthouseweb3/functions/encryptionManager/recover_key.py similarity index 98% rename from src/lighthouseweb3/functions/encryptionManager/recoverKey.py rename to src/lighthouseweb3/functions/encryptionManager/recover_key.py index 8856723..f03d54a 100644 --- a/src/lighthouseweb3/functions/encryptionManager/recoverKey.py +++ b/src/lighthouseweb3/functions/encryptionManager/recover_key.py @@ -134,7 +134,7 @@ def validate_share(share: Dict[str, str], index: int) -> Tuple[str, str]: return key_str, index_str -async def recoverKey(keyShards: List[Dict[str, str]]) -> Dict[str, Any]: +async def recover_key(keyShards: List[Dict[str, str]]) -> Dict[str, Any]: """ Recover the master key from a subset of key shares using Lagrange interpolation. From 0783ebe832a5a8bfbde6aaaa62494b14668b95ef Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Wed, 2 Jul 2025 19:12:58 +0000 Subject: [PATCH 05/18] feat:added shard key method --- src/lighthouseweb3/__init__.py | 10 +- .../functions/encryptionManager/generate.py | 66 +++-------- .../functions/encryptionManager/shard_key.py | 105 ++++++++++++++++++ 3 files changed, 127 insertions(+), 54 deletions(-) create mode 100644 src/lighthouseweb3/functions/encryptionManager/shard_key.py diff --git a/src/lighthouseweb3/__init__.py b/src/lighthouseweb3/__init__.py index 3a75113..4e41078 100644 --- a/src/lighthouseweb3/__init__.py +++ b/src/lighthouseweb3/__init__.py @@ -19,7 +19,8 @@ ) from .functions.encryptionManager import ( generate, - recover_key as recoverKey + recover_key as recoverKey, + shard_key as shardKey ) class Lighthouse: @@ -243,4 +244,11 @@ def recoverKey(keyShards: List[Dict[str, Any]]): return recoverKey.recover_key(keyShards) except Exception as e: raise e + + @staticmethod + def shardKey(masterKey: int, threshold: int, keyCount: int): + try: + return shardKey.shard_key(masterKey, threshold, keyCount) + except Exception as e: + raise e \ No newline at end of file diff --git a/src/lighthouseweb3/functions/encryptionManager/generate.py b/src/lighthouseweb3/functions/encryptionManager/generate.py index ab8c2f9..88d11a4 100644 --- a/src/lighthouseweb3/functions/encryptionManager/generate.py +++ b/src/lighthouseweb3/functions/encryptionManager/generate.py @@ -1,32 +1,9 @@ import secrets import logging from typing import Dict, List, Any -from .config import PRIME +from shard_key import shard_key logger = logging.getLogger(__name__) - -def evaluate_polynomial(coefficients: List[int], x: int, prime: int) -> int: - """ - Evaluate a polynomial with given coefficients at point x. - msk[0] is constant term (the secret), msk[1] is x coefficient, etc. - - Args: - coefficients: List of coefficients where coefficients[0] is the constant term - x: Point at which to evaluate the polynomial - prime: Prime number for the finite field - - Returns: - The result of the polynomial evaluation modulo prime - """ - result = 0 - x_power = 1 # x^0 = 1 - - for coefficient in coefficients: - result = (result + coefficient * x_power) % prime - x_power = (x_power * x) % prime - - return result - async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: """ Generate threshold cryptography key shards using Shamir's Secret Sharing @@ -48,35 +25,18 @@ async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: """ logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}") - msk=[] - idVec=[] - secVec=[] - - if threshold > key_count: - raise ValueError("key_count must be greater than or equal to threshold") - if threshold < 1 or key_count < 1: - raise ValueError("threshold and key_count must be positive integers") - + try: + master_key = secrets.randbits(256) + result = await shard_key(master_key, threshold, key_count) - msk = [secrets.randbits(256) for _ in range(threshold)] - master_key = msk[0] + if not result['isShardable']: + raise ValueError(result['error']) - used_ids = set() - - for i in range(key_count): - while True: - id_vec = secrets.randbits(32) - if id_vec != 0 and id_vec not in used_ids: - idVec.append(id_vec) - used_ids.add(id_vec) - break - - for i in range(key_count): - y = evaluate_polynomial(msk, idVec[i], PRIME) - secVec.append(y) + return { + "masterKey": hex(master_key), + "keyShards": result['keyShards'] + } - result = { - "masterKey": hex(master_key), - "keyShards": [{"key": hex(secVec[i]), "index": hex(idVec[i])} for i in range(key_count)] - } - return result + except Exception as e: + logger.error(f"Error during key generation: {str(e)}") + raise e diff --git a/src/lighthouseweb3/functions/encryptionManager/shard_key.py b/src/lighthouseweb3/functions/encryptionManager/shard_key.py new file mode 100644 index 0000000..c316273 --- /dev/null +++ b/src/lighthouseweb3/functions/encryptionManager/shard_key.py @@ -0,0 +1,105 @@ +import secrets +import logging +from typing import Dict, List, Any +from config import PRIME +logger = logging.getLogger(__name__) + + +def evaluate_polynomial(coefficients: List[int], x: int, prime: int) -> int: + """ + Evaluate a polynomial with given coefficients at point x. + msk[0] is constant term (the secret), msk[1] is x coefficient, etc. + + Args: + coefficients: List of coefficients where coefficients[0] is the constant term + x: Point at which to evaluate the polynomial + prime: Prime number for the finite field + + Returns: + { + "isShardable": true, + "keyShards": [ + { "key": "", "index": "" } + ] + } + """ + result = 0 + x_power = 1 # x^0 = 1 + + for coefficient in coefficients: + result = (result + coefficient * x_power) % prime + x_power = (x_power * x) % prime + + return result + +def validate_key(key: str) -> bool: + """ + Validate that the given key is a valid 32-byte (64 hex char) string. + """ + try: + bytes.fromhex(key) + return len(key) == 64 + except ValueError: + return False + +async def shard_key(key: str, threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: + """ + Generate threshold cryptography key shards using Shamir's Secret Sharing + + Args: + key: The key to be shared + threshold: Minimum number of shards needed to reconstruct the secret + key_count: Total number of key shards to generate + + Returns: + { + "isShardable": true, + "keyShards": [ + { + "key": "", + "index": "" + } + ] + } + """ + logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}") + + try: + msk=[] + idVec=[] + secVec=[] + + if threshold > key_count: + raise ValueError("key_count must be greater than or equal to threshold") + if threshold < 1 or key_count < 1: + raise ValueError("threshold and key_count must be positive integers") + + + msk.append(key) + + used_ids = set() + + for i in range(key_count): + while True: + id_vec = secrets.randbits(32) + if id_vec != 0 and id_vec not in used_ids: + idVec.append(id_vec) + used_ids.add(id_vec) + break + + for i in range(key_count): + y = evaluate_polynomial(msk, idVec[i], PRIME) + secVec.append(y) + + result = { + "isShardable": True, + "keyShards": [{"key": hex(secVec[i]), "index": hex(idVec[i])} for i in range(key_count)] + } + except Exception as e: + logger.error(f"Error generating key shards: {str(e)}") + result = { + "isShardable": False, + "error": str(e) + } + + return result From 94a1670eac3aa998439fca6ef565b05f45290a6f Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Wed, 2 Jul 2025 19:42:38 +0000 Subject: [PATCH 06/18] refactor: updated the hex validation func to use bytes.fromhex --- .../encryptionManager/recover_key.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/lighthouseweb3/functions/encryptionManager/recover_key.py b/src/lighthouseweb3/functions/encryptionManager/recover_key.py index f03d54a..f0a6165 100644 --- a/src/lighthouseweb3/functions/encryptionManager/recover_key.py +++ b/src/lighthouseweb3/functions/encryptionManager/recover_key.py @@ -112,19 +112,26 @@ def validate_share(share: Dict[str, str], index: int) -> Tuple[str, str]: key_str = key_str[2:] if index_str.startswith('0x'): index_str = index_str[2:] - if not key_str: raise ValueError(f"Empty key in share at index {index}") - if not all(c in '0123456789abcdef' for c in key_str): - raise ValueError(f"Invalid key format in share at index {index}: must be a valid hex string") + if not index_str: + raise ValueError(f"Empty index in share at index {index}") if len(key_str) % 2 != 0: key_str = '0' + key_str - - if not index_str: - raise ValueError(f"Empty index in share at index {index}") - if not all(c in '0123456789abcdef' for c in index_str): + + if len(index_str) % 2 != 0: + index_str = '0' + index_str + + try: + bytes.fromhex(key_str) + except ValueError: + raise ValueError(f"Invalid key format in share at index {index}: must be a valid hex string") + + try: + bytes.fromhex(index_str) + except ValueError: raise ValueError(f"Invalid index format in share at index {index}: must be a valid hex string") index_int = int(index_str, 16) @@ -133,7 +140,6 @@ def validate_share(share: Dict[str, str], index: int) -> Tuple[str, str]: return key_str, index_str - async def recover_key(keyShards: List[Dict[str, str]]) -> Dict[str, Any]: """ Recover the master key from a subset of key shares using Lagrange interpolation. @@ -168,4 +174,4 @@ async def recover_key(keyShards: List[Dict[str, str]]) -> Dict[str, Any]: return { "masterKey": None, "error": f"Recovery error: {str(e)}" - } + } \ No newline at end of file From 2044cc2431fc020ad44f54eb4ab3fed71816c7ac Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Wed, 2 Jul 2025 19:45:12 +0000 Subject: [PATCH 07/18] fix:updated the validation hex code --- .../encryptionManager/recover_key.py | 21 ++++++++++++------- .../functions/encryptionManager/shard_key.py | 3 +++ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/lighthouseweb3/functions/encryptionManager/recover_key.py b/src/lighthouseweb3/functions/encryptionManager/recover_key.py index f03d54a..4a082dc 100644 --- a/src/lighthouseweb3/functions/encryptionManager/recover_key.py +++ b/src/lighthouseweb3/functions/encryptionManager/recover_key.py @@ -112,19 +112,26 @@ def validate_share(share: Dict[str, str], index: int) -> Tuple[str, str]: key_str = key_str[2:] if index_str.startswith('0x'): index_str = index_str[2:] - if not key_str: raise ValueError(f"Empty key in share at index {index}") - if not all(c in '0123456789abcdef' for c in key_str): - raise ValueError(f"Invalid key format in share at index {index}: must be a valid hex string") + if not index_str: + raise ValueError(f"Empty index in share at index {index}") if len(key_str) % 2 != 0: key_str = '0' + key_str - - if not index_str: - raise ValueError(f"Empty index in share at index {index}") - if not all(c in '0123456789abcdef' for c in index_str): + + if len(index_str) % 2 != 0: + index_str = '0' + index_str + + try: + bytes.fromhex(key_str) + except ValueError: + raise ValueError(f"Invalid key format in share at index {index}: must be a valid hex string") + + try: + bytes.fromhex(index_str) + except ValueError: raise ValueError(f"Invalid index format in share at index {index}: must be a valid hex string") index_int = int(index_str, 16) diff --git a/src/lighthouseweb3/functions/encryptionManager/shard_key.py b/src/lighthouseweb3/functions/encryptionManager/shard_key.py index c316273..0459037 100644 --- a/src/lighthouseweb3/functions/encryptionManager/shard_key.py +++ b/src/lighthouseweb3/functions/encryptionManager/shard_key.py @@ -64,6 +64,9 @@ async def shard_key(key: str, threshold: int = 3, key_count: int = 5) -> Dict[st """ logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}") + if not validate_key(key): + raise ValueError("Invalid key format: must be a valid hex string") + try: msk=[] idVec=[] From a108014bafac187bc893a793851a902e6ccf4597 Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Wed, 2 Jul 2025 20:11:07 +0000 Subject: [PATCH 08/18] fix --- src/lighthouseweb3/functions/encryptionManager/generate.py | 6 +++--- src/lighthouseweb3/functions/encryptionManager/shard_key.py | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/lighthouseweb3/functions/encryptionManager/generate.py b/src/lighthouseweb3/functions/encryptionManager/generate.py index 88d11a4..e115ad1 100644 --- a/src/lighthouseweb3/functions/encryptionManager/generate.py +++ b/src/lighthouseweb3/functions/encryptionManager/generate.py @@ -1,7 +1,7 @@ import secrets import logging from typing import Dict, List, Any -from shard_key import shard_key +from .shard_key import shard_key logger = logging.getLogger(__name__) async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: @@ -26,14 +26,14 @@ async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}") try: - master_key = secrets.randbits(256) + master_key = hex(secrets.randbits(256)) result = await shard_key(master_key, threshold, key_count) if not result['isShardable']: raise ValueError(result['error']) return { - "masterKey": hex(master_key), + "masterKey": master_key, "keyShards": result['keyShards'] } diff --git a/src/lighthouseweb3/functions/encryptionManager/shard_key.py b/src/lighthouseweb3/functions/encryptionManager/shard_key.py index 0459037..a0be43e 100644 --- a/src/lighthouseweb3/functions/encryptionManager/shard_key.py +++ b/src/lighthouseweb3/functions/encryptionManager/shard_key.py @@ -1,7 +1,7 @@ import secrets import logging from typing import Dict, List, Any -from config import PRIME +from .config import PRIME logger = logging.getLogger(__name__) @@ -37,6 +37,8 @@ def validate_key(key: str) -> bool: Validate that the given key is a valid 32-byte (64 hex char) string. """ try: + if key.startswith('0x'): + key = key[2:] bytes.fromhex(key) return len(key) == 64 except ValueError: @@ -67,6 +69,8 @@ async def shard_key(key: str, threshold: int = 3, key_count: int = 5) -> Dict[st if not validate_key(key): raise ValueError("Invalid key format: must be a valid hex string") + key = int(key, 16) + try: msk=[] idVec=[] From a2f37c575dec579659126f719ed4e131d6c79e7a Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Sun, 6 Jul 2025 18:03:36 +0000 Subject: [PATCH 09/18] not working --- .../functions/encryptionManager/shard_key.py | 2 -- .../test_recover_key.py | 5 ++- .../tests_encryptionEngine/test_shard_key.py | 34 +++++++++++++++++++ 3 files changed, 36 insertions(+), 5 deletions(-) create mode 100644 tests/tests_encryptionEngine/test_shard_key.py diff --git a/src/lighthouseweb3/functions/encryptionManager/shard_key.py b/src/lighthouseweb3/functions/encryptionManager/shard_key.py index a0be43e..149dbd5 100644 --- a/src/lighthouseweb3/functions/encryptionManager/shard_key.py +++ b/src/lighthouseweb3/functions/encryptionManager/shard_key.py @@ -80,8 +80,6 @@ async def shard_key(key: str, threshold: int = 3, key_count: int = 5) -> Dict[st raise ValueError("key_count must be greater than or equal to threshold") if threshold < 1 or key_count < 1: raise ValueError("threshold and key_count must be positive integers") - - msk.append(key) used_ids = set() diff --git a/tests/tests_encryptionEngine/test_recover_key.py b/tests/tests_encryptionEngine/test_recover_key.py index e159687..b1a7202 100644 --- a/tests/tests_encryptionEngine/test_recover_key.py +++ b/tests/tests_encryptionEngine/test_recover_key.py @@ -88,7 +88,6 @@ async def run_test(): (4, 7), (3, 10), ] - for threshold, total in test_cases: with self.subTest(threshold=threshold, total=total): gen_result = await EncryptionManager.generate( @@ -113,8 +112,8 @@ async def run_test(): self.assertIsNone(result['error']) return asyncio.run(run_test()) - - + + def test_invalid_share_format(self): """Test that invalid share formats are handled correctly.""" async def run_test(): diff --git a/tests/tests_encryptionEngine/test_shard_key.py b/tests/tests_encryptionEngine/test_shard_key.py new file mode 100644 index 0000000..7d71971 --- /dev/null +++ b/tests/tests_encryptionEngine/test_shard_key.py @@ -0,0 +1,34 @@ +import unittest +import asyncio +import logging +from src.lighthouseweb3 import EncryptionManager + +logger = logging.getLogger(__name__) + +# class TestShardKey(unittest.TestCase): +# def test_valid_32_byte_key(self): +# """Test with a valid 32-byte key.""" + +# async def run_test(): +# key = "0xb51cde71e810430c9f657dd24d5ba30b17ec1f86e9f671c7f4cb3d888a4680dd" +# result = await EncryptionManager.shardKey(key, threshold=3, keyCount=5) +# self.assertTrue(result['isShardable']) +# self.assertEqual(len(result['keyShards']), 5) + +# for shard in result['keyShards']: +# self.assertIn('key', shard) +# self.assertIn('index', shard) +# self.assertIsInstance(shard['key'], str) +# self.assertIsInstance(shard['index'], str) + +# return asyncio.run(run_test()) + + # def test_invalid_key(self): + # """Test with an invalid key.""" + # async def run_test(): + # key = "e810430c9f657dd24d5ba30b17ec1f86e9f671c7f4cb3d888a4680dd" + # result = await EncryptionManager.shardKey(key, threshold=3, keyCount=5) + # self.assertFalse(result['isShardable']) + # self.assertEqual(result['error'], "Invalid key length") + + # return asyncio.run(run_test()) \ No newline at end of file From e78b5e0fe78b8240282827144a27fa6dc0a4751a Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Sun, 6 Jul 2025 18:55:35 +0000 Subject: [PATCH 10/18] feat: added shard_key method --- .../functions/encryptionManager/generate.py | 17 +- .../encryptionManager/recover_key.py | 3 +- .../test_recover_key.py | 2 +- .../tests_encryptionEngine/test_shard_key.py | 176 +++++++++++++++--- 4 files changed, 162 insertions(+), 36 deletions(-) diff --git a/src/lighthouseweb3/functions/encryptionManager/generate.py b/src/lighthouseweb3/functions/encryptionManager/generate.py index e115ad1..0071b0f 100644 --- a/src/lighthouseweb3/functions/encryptionManager/generate.py +++ b/src/lighthouseweb3/functions/encryptionManager/generate.py @@ -1,7 +1,8 @@ import secrets import logging from typing import Dict, List, Any -from .shard_key import shard_key +from .shard_key import shard_key + logger = logging.getLogger(__name__) async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: @@ -24,19 +25,21 @@ async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: } """ logger.info(f"Generating key shards with threshold={threshold}, key_count={key_count}") - + try: - master_key = hex(secrets.randbits(256)) + random_int = secrets.randbits(256) + master_key = f"0x{random_int:064x}" + result = await shard_key(master_key, threshold, key_count) - + if not result['isShardable']: raise ValueError(result['error']) - + return { "masterKey": master_key, "keyShards": result['keyShards'] } - + except Exception as e: logger.error(f"Error during key generation: {str(e)}") - raise e + raise e \ No newline at end of file diff --git a/src/lighthouseweb3/functions/encryptionManager/recover_key.py b/src/lighthouseweb3/functions/encryptionManager/recover_key.py index f0a6165..b2b29b0 100644 --- a/src/lighthouseweb3/functions/encryptionManager/recover_key.py +++ b/src/lighthouseweb3/functions/encryptionManager/recover_key.py @@ -159,8 +159,9 @@ async def recover_key(keyShards: List[Dict[str, str]]) -> Dict[str, Any]: for i, share in enumerate(keyShards): validate_share(share, i) secret = lagrange_interpolation(keyShards, PRIME) + master_key = f"0x{secret:064x}" return { - "masterKey": hex(secret), + "masterKey": master_key, "error": None } except ValueError as e: diff --git a/tests/tests_encryptionEngine/test_recover_key.py b/tests/tests_encryptionEngine/test_recover_key.py index b1a7202..93b53ea 100644 --- a/tests/tests_encryptionEngine/test_recover_key.py +++ b/tests/tests_encryptionEngine/test_recover_key.py @@ -12,7 +12,7 @@ def test_empty_shares_list(self): """Test that recovery fails with empty shares list.""" async def run_test(): result = await EncryptionManager.recoverKey([]) - self.assertEqual(result['masterKey'], '0x0') + self.assertEqual(result['masterKey'], '0x0000000000000000000000000000000000000000000000000000000000000000') self.assertIsNone(result['error']) return asyncio.run(run_test()) diff --git a/tests/tests_encryptionEngine/test_shard_key.py b/tests/tests_encryptionEngine/test_shard_key.py index 7d71971..d5a9eb3 100644 --- a/tests/tests_encryptionEngine/test_shard_key.py +++ b/tests/tests_encryptionEngine/test_shard_key.py @@ -5,30 +5,152 @@ logger = logging.getLogger(__name__) -# class TestShardKey(unittest.TestCase): -# def test_valid_32_byte_key(self): -# """Test with a valid 32-byte key.""" - -# async def run_test(): -# key = "0xb51cde71e810430c9f657dd24d5ba30b17ec1f86e9f671c7f4cb3d888a4680dd" -# result = await EncryptionManager.shardKey(key, threshold=3, keyCount=5) -# self.assertTrue(result['isShardable']) -# self.assertEqual(len(result['keyShards']), 5) - -# for shard in result['keyShards']: -# self.assertIn('key', shard) -# self.assertIn('index', shard) -# self.assertIsInstance(shard['key'], str) -# self.assertIsInstance(shard['index'], str) - -# return asyncio.run(run_test()) - - # def test_invalid_key(self): - # """Test with an invalid key.""" - # async def run_test(): - # key = "e810430c9f657dd24d5ba30b17ec1f86e9f671c7f4cb3d888a4680dd" - # result = await EncryptionManager.shardKey(key, threshold=3, keyCount=5) - # self.assertFalse(result['isShardable']) - # self.assertEqual(result['error'], "Invalid key length") - - # return asyncio.run(run_test()) \ No newline at end of file +class TestShardKey(unittest.TestCase): + """Test cases for the shardKey function.""" + + def test_shardKey_valid_32_byte_key(self): + """Test shardKey with valid 32-byte keys.""" + async def run_test(): + valid_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + result = await EncryptionManager.shardKey(valid_key, threshold=2, keyCount=3) + + self.assertTrue(result['isShardable']) + self.assertIn('keyShards', result) + self.assertEqual(len(result['keyShards']), 3) + + for shard in result['keyShards']: + self.assertIn('key', shard) + self.assertIn('index', shard) + self.assertTrue(shard['key'].startswith('0x')) + self.assertTrue(shard['index'].startswith('0x')) + self.assertTrue(all(c in '0123456789abcdef' for c in shard['key'][2:])) + self.assertTrue(all(c in '0123456789abcdef' for c in shard['index'][2:])) + + valid_key_with_prefix = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + result2 = await EncryptionManager.shardKey(valid_key_with_prefix, threshold=2, keyCount=3) + + self.assertTrue(result2['isShardable']) + self.assertEqual(len(result2['keyShards']), 3) + + return result + + return asyncio.run(run_test()) + + def test_shardKey_invalid_keys(self): + """Test shardKey with invalid keys.""" + async def run_test(): + short_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcd" + with self.assertRaises(ValueError) as context: + await EncryptionManager.shardKey(short_key, threshold=2, keyCount=3) + self.assertIn("Invalid key format", str(context.exception)) + + long_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef12" + with self.assertRaises(ValueError) as context: + await EncryptionManager.shardKey(long_key, threshold=2, keyCount=3) + self.assertIn("Invalid key format", str(context.exception)) + + malformed_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdefg" + with self.assertRaises(ValueError) as context: + await EncryptionManager.shardKey(malformed_key, threshold=2, keyCount=3) + self.assertIn("Invalid key format", str(context.exception)) + + with self.assertRaises(ValueError) as context: + await EncryptionManager.shardKey("", threshold=2, keyCount=3) + self.assertIn("Invalid key format", str(context.exception)) + + invalid_hex = "xyz4567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + with self.assertRaises(ValueError) as context: + await EncryptionManager.shardKey(invalid_hex, threshold=2, keyCount=3) + self.assertIn("Invalid key format", str(context.exception)) + + return asyncio.run(run_test()) + + def test_shardKey_threshold_keyCount_combinations(self): + """Test various threshold and keyCount combinations.""" + async def run_test(): + valid_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + + result1 = await EncryptionManager.shardKey(valid_key, threshold=1, keyCount=1) + self.assertTrue(result1['isShardable']) + self.assertEqual(len(result1['keyShards']), 1) + + result2 = await EncryptionManager.shardKey(valid_key, threshold=2, keyCount=3) + self.assertTrue(result2['isShardable']) + self.assertEqual(len(result2['keyShards']), 3) + + result3 = await EncryptionManager.shardKey(valid_key, threshold=3, keyCount=5) + self.assertTrue(result3['isShardable']) + self.assertEqual(len(result3['keyShards']), 5) + + result4 = await EncryptionManager.shardKey(valid_key, threshold=4, keyCount=4) + self.assertTrue(result4['isShardable']) + self.assertEqual(len(result4['keyShards']), 4) + + result5 = await EncryptionManager.shardKey(valid_key, threshold=5, keyCount=10) + self.assertTrue(result5['isShardable']) + self.assertEqual(len(result5['keyShards']), 10) + + indices = [shard['index'] for shard in result5['keyShards']] + self.assertEqual(len(set(indices)), 10) + + return result5 + + return asyncio.run(run_test()) + + + def test_shardKey_index_uniqueness(self): + """Test that all generated indices are unique and non-zero.""" + async def run_test(): + valid_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + + + result = await EncryptionManager.shardKey(valid_key, threshold=3, keyCount=20) + + self.assertTrue(result['isShardable']) + self.assertEqual(len(result['keyShards']), 20) + + indices = [shard['index'] for shard in result['keyShards']] + self.assertEqual(len(set(indices)), 20) + + for index in indices: + self.assertNotEqual(index, '0x0') + + self.assertNotEqual(int(index, 16), 0) + + return result + + return asyncio.run(run_test()) + + def test_shardKey_hex_format_consistency(self): + """Test that all returned values are properly formatted hex strings.""" + async def run_test(): + valid_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + + result = await EncryptionManager.shardKey(valid_key, threshold=2, keyCount=4) + + self.assertTrue(result['isShardable']) + + for shard in result['keyShards']: + key = shard['key'] + index = shard['index'] + + + self.assertTrue(key.startswith('0x')) + self.assertTrue(index.startswith('0x')) + + + self.assertTrue(all(c in '0123456789abcdef' for c in key[2:])) + self.assertTrue(all(c in '0123456789abcdef' for c in index[2:])) + + try: + int(key, 16) + int(index, 16) + except ValueError: + self.fail(f"Invalid hex format: key={key}, index={index}") + + return result + + return asyncio.run(run_test()) + +if __name__ == '__main__': + unittest.main(verbosity=2) \ No newline at end of file From c1080015fc1504152493ec41484463f94e30fd60 Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Sun, 6 Jul 2025 19:17:39 +0000 Subject: [PATCH 11/18] fix:fixed error --- .../functions/encryptionManager/generate.py | 2 +- .../encryptionManager/recover_key.py | 2 +- .../functions/encryptionManager/shard_key.py | 35 ++++++++++--------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/lighthouseweb3/functions/encryptionManager/generate.py b/src/lighthouseweb3/functions/encryptionManager/generate.py index 0071b0f..bc700e8 100644 --- a/src/lighthouseweb3/functions/encryptionManager/generate.py +++ b/src/lighthouseweb3/functions/encryptionManager/generate.py @@ -42,4 +42,4 @@ async def generate(threshold: int = 3, key_count: int = 5) -> Dict[str, Any]: except Exception as e: logger.error(f"Error during key generation: {str(e)}") - raise e \ No newline at end of file + raise e diff --git a/src/lighthouseweb3/functions/encryptionManager/recover_key.py b/src/lighthouseweb3/functions/encryptionManager/recover_key.py index b2b29b0..a7970b7 100644 --- a/src/lighthouseweb3/functions/encryptionManager/recover_key.py +++ b/src/lighthouseweb3/functions/encryptionManager/recover_key.py @@ -175,4 +175,4 @@ async def recover_key(keyShards: List[Dict[str, str]]) -> Dict[str, Any]: return { "masterKey": None, "error": f"Recovery error: {str(e)}" - } \ No newline at end of file + } diff --git a/src/lighthouseweb3/functions/encryptionManager/shard_key.py b/src/lighthouseweb3/functions/encryptionManager/shard_key.py index 149dbd5..4910965 100644 --- a/src/lighthouseweb3/functions/encryptionManager/shard_key.py +++ b/src/lighthouseweb3/functions/encryptionManager/shard_key.py @@ -2,13 +2,13 @@ import logging from typing import Dict, List, Any from .config import PRIME -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) def evaluate_polynomial(coefficients: List[int], x: int, prime: int) -> int: """ Evaluate a polynomial with given coefficients at point x. - msk[0] is constant term (the secret), msk[1] is x coefficient, etc. + coefficients[0] is constant term (the secret), coefficients[1] is x coefficient, etc. Args: coefficients: List of coefficients where coefficients[0] is the constant term @@ -16,15 +16,10 @@ def evaluate_polynomial(coefficients: List[int], x: int, prime: int) -> int: prime: Prime number for the finite field Returns: - { - "isShardable": true, - "keyShards": [ - { "key": "", "index": "" } - ] - } + Value of polynomial at point x """ result = 0 - x_power = 1 # x^0 = 1 + x_power = 1 # x^0 = 1 for coefficient in coefficients: result = (result + coefficient * x_power) % prime @@ -69,29 +64,33 @@ async def shard_key(key: str, threshold: int = 3, key_count: int = 5) -> Dict[st if not validate_key(key): raise ValueError("Invalid key format: must be a valid hex string") - key = int(key, 16) + key_int = int(key, 16) try: - msk=[] - idVec=[] - secVec=[] - if threshold > key_count: raise ValueError("key_count must be greater than or equal to threshold") if threshold < 1 or key_count < 1: raise ValueError("threshold and key_count must be positive integers") - msk.append(key) - + + msk = [key_int] + + for i in range(threshold - 1): + random_coeff = secrets.randbelow(PRIME) + msk.append(random_coeff) + + idVec = [] used_ids = set() for i in range(key_count): while True: id_vec = secrets.randbits(32) - if id_vec != 0 and id_vec not in used_ids: + + if id_vec != 0 and id_vec not in used_ids and id_vec < PRIME: idVec.append(id_vec) used_ids.add(id_vec) break + secVec = [] for i in range(key_count): y = evaluate_polynomial(msk, idVec[i], PRIME) secVec.append(y) @@ -100,6 +99,7 @@ async def shard_key(key: str, threshold: int = 3, key_count: int = 5) -> Dict[st "isShardable": True, "keyShards": [{"key": hex(secVec[i]), "index": hex(idVec[i])} for i in range(key_count)] } + except Exception as e: logger.error(f"Error generating key shards: {str(e)}") result = { @@ -108,3 +108,4 @@ async def shard_key(key: str, threshold: int = 3, key_count: int = 5) -> Dict[st } return result + From 400de7ac64f948b5f0cccb3407b7e7102b12be02 Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Thu, 10 Jul 2025 16:44:50 +0000 Subject: [PATCH 12/18] refactor: changed class name --- src/lighthouseweb3/__init__.py | 4 +- .../{encryptionManager => kavach}/__init__.py | 0 .../{encryptionManager => kavach}/config.py | 0 .../{encryptionManager => kavach}/generate.py | 0 .../recover_key.py | 0 .../shard_key.py | 0 .../__init__.py | 0 .../test_generate.py | 10 ++--- .../test_recover_key.py | 38 +++++++++---------- .../test_shard_key.py | 30 +++++++-------- 10 files changed, 41 insertions(+), 41 deletions(-) rename src/lighthouseweb3/functions/{encryptionManager => kavach}/__init__.py (100%) rename src/lighthouseweb3/functions/{encryptionManager => kavach}/config.py (100%) rename src/lighthouseweb3/functions/{encryptionManager => kavach}/generate.py (100%) rename src/lighthouseweb3/functions/{encryptionManager => kavach}/recover_key.py (100%) rename src/lighthouseweb3/functions/{encryptionManager => kavach}/shard_key.py (100%) rename tests/{tests_encryptionEngine => tests_kavach}/__init__.py (100%) rename tests/{tests_encryptionEngine => tests_kavach}/test_generate.py (88%) rename tests/{tests_encryptionEngine => tests_kavach}/test_recover_key.py (76%) rename tests/{tests_encryptionEngine => tests_kavach}/test_shard_key.py (81%) diff --git a/src/lighthouseweb3/__init__.py b/src/lighthouseweb3/__init__.py index 4e41078..6078bba 100644 --- a/src/lighthouseweb3/__init__.py +++ b/src/lighthouseweb3/__init__.py @@ -17,7 +17,7 @@ remove_ipns_record as removeIpnsRecord, create_wallet as createWallet ) -from .functions.encryptionManager import ( +from .functions.kavach import ( generate, recover_key as recoverKey, shard_key as shardKey @@ -229,7 +229,7 @@ def getTagged(self, tag: str): except Exception as e: raise e -class EncryptionManager: +class Kavach: @staticmethod def generate(threshold: int, keyCount: int): try: diff --git a/src/lighthouseweb3/functions/encryptionManager/__init__.py b/src/lighthouseweb3/functions/kavach/__init__.py similarity index 100% rename from src/lighthouseweb3/functions/encryptionManager/__init__.py rename to src/lighthouseweb3/functions/kavach/__init__.py diff --git a/src/lighthouseweb3/functions/encryptionManager/config.py b/src/lighthouseweb3/functions/kavach/config.py similarity index 100% rename from src/lighthouseweb3/functions/encryptionManager/config.py rename to src/lighthouseweb3/functions/kavach/config.py diff --git a/src/lighthouseweb3/functions/encryptionManager/generate.py b/src/lighthouseweb3/functions/kavach/generate.py similarity index 100% rename from src/lighthouseweb3/functions/encryptionManager/generate.py rename to src/lighthouseweb3/functions/kavach/generate.py diff --git a/src/lighthouseweb3/functions/encryptionManager/recover_key.py b/src/lighthouseweb3/functions/kavach/recover_key.py similarity index 100% rename from src/lighthouseweb3/functions/encryptionManager/recover_key.py rename to src/lighthouseweb3/functions/kavach/recover_key.py diff --git a/src/lighthouseweb3/functions/encryptionManager/shard_key.py b/src/lighthouseweb3/functions/kavach/shard_key.py similarity index 100% rename from src/lighthouseweb3/functions/encryptionManager/shard_key.py rename to src/lighthouseweb3/functions/kavach/shard_key.py diff --git a/tests/tests_encryptionEngine/__init__.py b/tests/tests_kavach/__init__.py similarity index 100% rename from tests/tests_encryptionEngine/__init__.py rename to tests/tests_kavach/__init__.py diff --git a/tests/tests_encryptionEngine/test_generate.py b/tests/tests_kavach/test_generate.py similarity index 88% rename from tests/tests_encryptionEngine/test_generate.py rename to tests/tests_kavach/test_generate.py index 3408ed3..386fe79 100644 --- a/tests/tests_encryptionEngine/test_generate.py +++ b/tests/tests_kavach/test_generate.py @@ -1,7 +1,7 @@ import unittest import asyncio import logging -from src.lighthouseweb3 import EncryptionManager +from src.lighthouseweb3 import Kavach logger = logging.getLogger(__name__) @@ -11,7 +11,7 @@ class TestGenerate(unittest.TestCase): def test_generate_basic(self): """Test basic key generation with default parameters.""" async def run_test(): - result = await EncryptionManager.generate(threshold=2, keyCount=3) + result = await Kavach.generate(threshold=2, keyCount=3) self.assertIn('masterKey', result) self.assertIn('keyShards', result) @@ -45,7 +45,7 @@ async def run_test(): threshold = 3 key_count = 5 - result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count) + result = await Kavach.generate(threshold=threshold, keyCount=key_count) self.assertEqual(len(result['keyShards']), key_count) @@ -66,11 +66,11 @@ def test_invalid_threshold(self): """Test that invalid threshold raises an error.""" async def run_test(): with self.assertRaises(ValueError) as context: - await EncryptionManager.generate(threshold=0, keyCount=3) + await Kavach.generate(threshold=0, keyCount=3) self.assertIn("must be positive integers", str(context.exception)) with self.assertRaises(ValueError) as context: - await EncryptionManager.generate(threshold=4, keyCount=3) + await Kavach.generate(threshold=4, keyCount=3) self.assertIn("must be greater than or equal to threshold", str(context.exception)) return asyncio.run(run_test()) diff --git a/tests/tests_encryptionEngine/test_recover_key.py b/tests/tests_kavach/test_recover_key.py similarity index 76% rename from tests/tests_encryptionEngine/test_recover_key.py rename to tests/tests_kavach/test_recover_key.py index 93b53ea..3ea2d30 100644 --- a/tests/tests_encryptionEngine/test_recover_key.py +++ b/tests/tests_kavach/test_recover_key.py @@ -1,7 +1,7 @@ import unittest import asyncio import logging -from src.lighthouseweb3 import EncryptionManager +from src.lighthouseweb3 import Kavach logger = logging.getLogger(__name__) @@ -11,7 +11,7 @@ class TestRecoverKey(unittest.TestCase): def test_empty_shares_list(self): """Test that recovery fails with empty shares list.""" async def run_test(): - result = await EncryptionManager.recoverKey([]) + result = await Kavach.recoverKey([]) self.assertEqual(result['masterKey'], '0x0000000000000000000000000000000000000000000000000000000000000000') self.assertIsNone(result['error']) @@ -24,17 +24,17 @@ async def run_test(): threshold = 3 key_count = 5 - gen_result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count) + gen_result = await Kavach.generate(threshold=threshold, keyCount=key_count) master_key = gen_result['masterKey'] shares = gen_result['keyShards'][:threshold] - result = await EncryptionManager.recoverKey(shares) + result = await Kavach.recoverKey(shares) self.assertEqual(result['masterKey'], master_key) self.assertIsNone(result['error']) for i in range(key_count - threshold + 1): subset = gen_result['keyShards'][i:i+threshold] - result = await EncryptionManager.recoverKey(subset) + result = await Kavach.recoverKey(subset) self.assertEqual(result['masterKey'], master_key) self.assertIsNone(result['error']) @@ -47,14 +47,14 @@ def test_recover_key_insufficient_shares(self): async def run_test(): threshold = 2 key_count = 5 - gen_result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count) + gen_result = await Kavach.generate(threshold=threshold, keyCount=key_count) master_key = gen_result['masterKey'] shares = gen_result['keyShards'][:threshold] - result = await EncryptionManager.recoverKey(shares) + result = await Kavach.recoverKey(shares) self.assertEqual(result['masterKey'], master_key) self.assertIsNone(result['error']) - result = await EncryptionManager.recoverKey(gen_result['keyShards']) + result = await Kavach.recoverKey(gen_result['keyShards']) self.assertEqual(result['masterKey'], master_key) self.assertIsNone(result['error']) @@ -65,15 +65,15 @@ def test_insufficient_shares(self): async def run_test(): threshold = 3 key_count = 5 - gen_result = await EncryptionManager.generate(threshold=threshold, keyCount=key_count) + gen_result = await Kavach.generate(threshold=threshold, keyCount=key_count) # Test with one less than threshold (should still work as long as we have at least 2 shares) - result = await EncryptionManager.recoverKey(gen_result['keyShards'][:threshold-1]) + result = await Kavach.recoverKey(gen_result['keyShards'][:threshold-1]) self.assertIsNotNone(result['masterKey']) self.assertIsNone(result['error']) # Test with single share (should still work as long as we have at least 1 share) - result = await EncryptionManager.recoverKey(gen_result['keyShards'][:1]) + result = await Kavach.recoverKey(gen_result['keyShards'][:1]) self.assertIsNotNone(result['masterKey']) self.assertIsNone(result['error']) @@ -90,24 +90,24 @@ async def run_test(): ] for threshold, total in test_cases: with self.subTest(threshold=threshold, total=total): - gen_result = await EncryptionManager.generate( + gen_result = await Kavach.generate( threshold=threshold, keyCount=total ) master_key = gen_result['masterKey'] shares = gen_result['keyShards'][:threshold] - result = await EncryptionManager.recoverKey(shares) + result = await Kavach.recoverKey(shares) self.assertEqual(result['masterKey'], master_key) self.assertIsNone(result['error']) - result = await EncryptionManager.recoverKey(gen_result['keyShards']) + result = await Kavach.recoverKey(gen_result['keyShards']) self.assertEqual(result['masterKey'], master_key) self.assertIsNone(result['error']) import random subset = random.sample(gen_result['keyShards'], threshold + 1) - result = await EncryptionManager.recoverKey(subset) + result = await Kavach.recoverKey(subset) self.assertEqual(result['masterKey'], master_key) self.assertIsNone(result['error']) @@ -117,22 +117,22 @@ async def run_test(): def test_invalid_share_format(self): """Test that invalid share formats are handled correctly.""" async def run_test(): - result = await EncryptionManager.recoverKey(["not a dict", "another invalid"]) + result = await Kavach.recoverKey(["not a dict", "another invalid"]) self.assertIsNone(result['masterKey']) self.assertIn("must be a dictionary", result['error']) - result = await EncryptionManager.recoverKey([{'key': '123'}, {'key': '456'}]) + result = await Kavach.recoverKey([{'key': '123'}, {'key': '456'}]) self.assertIsNone(result['masterKey']) self.assertIn("missing required fields 'key' or 'index'", result['error'].lower()) - result = await EncryptionManager.recoverKey([ + result = await Kavach.recoverKey([ {'key': 'invalidhex', 'index': '1'}, {'key': 'invalidhex2', 'index': '2'} ]) self.assertIsNone(result['masterKey']) self.assertIn("invalid key format", result['error'].lower()) - result = await EncryptionManager.recoverKey([ + result = await Kavach.recoverKey([ {'key': 'a' * 63, 'index': 'invalidindex'}, {'key': 'b' * 63, 'index': 'invalidindex2'} ]) diff --git a/tests/tests_encryptionEngine/test_shard_key.py b/tests/tests_kavach/test_shard_key.py similarity index 81% rename from tests/tests_encryptionEngine/test_shard_key.py rename to tests/tests_kavach/test_shard_key.py index d5a9eb3..c3aa209 100644 --- a/tests/tests_encryptionEngine/test_shard_key.py +++ b/tests/tests_kavach/test_shard_key.py @@ -1,7 +1,7 @@ import unittest import asyncio import logging -from src.lighthouseweb3 import EncryptionManager +from src.lighthouseweb3 import Kavach logger = logging.getLogger(__name__) @@ -12,7 +12,7 @@ def test_shardKey_valid_32_byte_key(self): """Test shardKey with valid 32-byte keys.""" async def run_test(): valid_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" - result = await EncryptionManager.shardKey(valid_key, threshold=2, keyCount=3) + result = await Kavach.shardKey(valid_key, threshold=2, keyCount=3) self.assertTrue(result['isShardable']) self.assertIn('keyShards', result) @@ -27,7 +27,7 @@ async def run_test(): self.assertTrue(all(c in '0123456789abcdef' for c in shard['index'][2:])) valid_key_with_prefix = "0x1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" - result2 = await EncryptionManager.shardKey(valid_key_with_prefix, threshold=2, keyCount=3) + result2 = await Kavach.shardKey(valid_key_with_prefix, threshold=2, keyCount=3) self.assertTrue(result2['isShardable']) self.assertEqual(len(result2['keyShards']), 3) @@ -41,26 +41,26 @@ def test_shardKey_invalid_keys(self): async def run_test(): short_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcd" with self.assertRaises(ValueError) as context: - await EncryptionManager.shardKey(short_key, threshold=2, keyCount=3) + await Kavach.shardKey(short_key, threshold=2, keyCount=3) self.assertIn("Invalid key format", str(context.exception)) long_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef12" with self.assertRaises(ValueError) as context: - await EncryptionManager.shardKey(long_key, threshold=2, keyCount=3) + await Kavach.shardKey(long_key, threshold=2, keyCount=3) self.assertIn("Invalid key format", str(context.exception)) malformed_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdefg" with self.assertRaises(ValueError) as context: - await EncryptionManager.shardKey(malformed_key, threshold=2, keyCount=3) + await Kavach.shardKey(malformed_key, threshold=2, keyCount=3) self.assertIn("Invalid key format", str(context.exception)) with self.assertRaises(ValueError) as context: - await EncryptionManager.shardKey("", threshold=2, keyCount=3) + await Kavach.shardKey("", threshold=2, keyCount=3) self.assertIn("Invalid key format", str(context.exception)) invalid_hex = "xyz4567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" with self.assertRaises(ValueError) as context: - await EncryptionManager.shardKey(invalid_hex, threshold=2, keyCount=3) + await Kavach.shardKey(invalid_hex, threshold=2, keyCount=3) self.assertIn("Invalid key format", str(context.exception)) return asyncio.run(run_test()) @@ -70,23 +70,23 @@ def test_shardKey_threshold_keyCount_combinations(self): async def run_test(): valid_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" - result1 = await EncryptionManager.shardKey(valid_key, threshold=1, keyCount=1) + result1 = await Kavach.shardKey(valid_key, threshold=1, keyCount=1) self.assertTrue(result1['isShardable']) self.assertEqual(len(result1['keyShards']), 1) - result2 = await EncryptionManager.shardKey(valid_key, threshold=2, keyCount=3) + result2 = await Kavach.shardKey(valid_key, threshold=2, keyCount=3) self.assertTrue(result2['isShardable']) self.assertEqual(len(result2['keyShards']), 3) - result3 = await EncryptionManager.shardKey(valid_key, threshold=3, keyCount=5) + result3 = await Kavach.shardKey(valid_key, threshold=3, keyCount=5) self.assertTrue(result3['isShardable']) self.assertEqual(len(result3['keyShards']), 5) - result4 = await EncryptionManager.shardKey(valid_key, threshold=4, keyCount=4) + result4 = await Kavach.shardKey(valid_key, threshold=4, keyCount=4) self.assertTrue(result4['isShardable']) self.assertEqual(len(result4['keyShards']), 4) - result5 = await EncryptionManager.shardKey(valid_key, threshold=5, keyCount=10) + result5 = await Kavach.shardKey(valid_key, threshold=5, keyCount=10) self.assertTrue(result5['isShardable']) self.assertEqual(len(result5['keyShards']), 10) @@ -104,7 +104,7 @@ async def run_test(): valid_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" - result = await EncryptionManager.shardKey(valid_key, threshold=3, keyCount=20) + result = await Kavach.shardKey(valid_key, threshold=3, keyCount=20) self.assertTrue(result['isShardable']) self.assertEqual(len(result['keyShards']), 20) @@ -126,7 +126,7 @@ def test_shardKey_hex_format_consistency(self): async def run_test(): valid_key = "1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef" - result = await EncryptionManager.shardKey(valid_key, threshold=2, keyCount=4) + result = await Kavach.shardKey(valid_key, threshold=2, keyCount=4) self.assertTrue(result['isShardable']) From 291f3e9ba625e81ee9da3d3505b0bd6fe48a903f Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Fri, 11 Jul 2025 16:01:27 +0000 Subject: [PATCH 13/18] feat:added util and validator --- requirements.txt | 3 +- src/lighthouseweb3/functions/config.py | 3 + .../functions/kavach/access_control/main.py | 1 + .../kavach/access_control/validator.py | 290 ++++++++++++++++++ src/lighthouseweb3/functions/kavach/util.py | 103 +++++++ 5 files changed, 399 insertions(+), 1 deletion(-) create mode 100644 src/lighthouseweb3/functions/kavach/access_control/main.py create mode 100644 src/lighthouseweb3/functions/kavach/access_control/validator.py create mode 100644 src/lighthouseweb3/functions/kavach/util.py diff --git a/requirements.txt b/requirements.txt index 13c4766..7b040aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ charset-normalizer==3.1.0 idna==3.4 requests==2.31.0 urllib3==2.0.2 -eth-account==0.13.7 \ No newline at end of file +eth-account==0.13.7 +aiohttp \ No newline at end of file diff --git a/src/lighthouseweb3/functions/config.py b/src/lighthouseweb3/functions/config.py index 000c5ef..2f62d25 100644 --- a/src/lighthouseweb3/functions/config.py +++ b/src/lighthouseweb3/functions/config.py @@ -9,3 +9,6 @@ class Config: lighthouse_node = "https://node.lighthouse.storage" lighthouse_bls_node = "https://encryption.lighthouse.storage" lighthouse_gateway = "https://gateway.lighthouse.storage/ipfs" + + is_dev = False + lighthouse_bls_node_dev = "http://enctest.lighthouse.storage" \ No newline at end of file diff --git a/src/lighthouseweb3/functions/kavach/access_control/main.py b/src/lighthouseweb3/functions/kavach/access_control/main.py new file mode 100644 index 0000000..b2ed25b --- /dev/null +++ b/src/lighthouseweb3/functions/kavach/access_control/main.py @@ -0,0 +1 @@ +from .validator import updateConditionSchema, accessConditionSchema \ No newline at end of file diff --git a/src/lighthouseweb3/functions/kavach/access_control/validator.py b/src/lighthouseweb3/functions/kavach/access_control/validator.py new file mode 100644 index 0000000..5afdce7 --- /dev/null +++ b/src/lighthouseweb3/functions/kavach/access_control/validator.py @@ -0,0 +1,290 @@ +from typing import List, Optional, Union, Any +from pydantic import BaseModel, Field, validator, root_validator +from enum import Enum +import re + +SOLIDITY_TYPES = [ + "address", "address[]", "bool", "bool[]", + "bytes1", "bytes2", "bytes3", "bytes4", "bytes5", "bytes6", "bytes7", "bytes8", "bytes16", "bytes32", + "bytes1[]", "bytes2[]", "bytes3[]", "bytes4[]", "bytes5[]", "bytes6[]", "bytes7[]", "bytes8[]", "bytes16[]", "bytes32[]", + "uint8", "uint16", "uint24", "uint32", "uint40", "uint48", "uint64", "uint128", "uint192", "uint256", + "int8", "int16", "int24", "int32", "int40", "int48", "int64", "int128", "int192", "int256", + "uint8[]", "uint16[]", "uint24[]", "uint32[]", "uint40[]", "uint48[]", "uint64[]", "uint128[]", "uint192[]", "uint256[]", + "int8[]", "int16[]", "int24[]", "int32[]", "int40[]", "int48[]", "int64[]", "int128[]", "int192[]", "int256[]" +] + +SUPPORTED_CHAINS = { + "EVM": [], + "SOLANA": ["DEVNET", "TESTNET", "MAINNET"], + "COREUM": ["Coreum_Devnet", "Coreum_Testnet", "Coreum_Mainnet"], + "RADIX": ["Radix_Mainnet"] +} + +class ChainType(str, Enum): + EVM = "EVM" + SOLANA = "SOLANA" + COREUM = "COREUM" + RADIX = "RADIX" + +class DecryptionType(str, Enum): + ADDRESS = "ADDRESS" + ACCESS_CONDITIONS = "ACCESS_CONDITIONS" + +class StandardContractType(str, Enum): + ERC20 = "ERC20" + ERC721 = "ERC721" + ERC1155 = "ERC1155" + CUSTOM = "Custom" + EMPTY = "" + +class SolanaContractType(str, Enum): + SPL_TOKEN = "spl-token" + EMPTY = "" + +class Comparator(str, Enum): + EQUAL = "==" + GREATER_EQUAL = ">=" + LESS_EQUAL = "<=" + NOT_EQUAL = "!=" + GREATER = ">" + LESS = "<" + +class ReturnValueTest(BaseModel): + comparator: Comparator + value: Union[int, float, str, List[Any]] + +class PDAInterface(BaseModel): + offset: int = Field(ge=0) + selector: str + +class EVMCondition(BaseModel): + id: int = Field(ge=1) + standard_contract_type: StandardContractType = Field(alias="standardContractType") + contract_address: Optional[str] = Field(alias="contractAddress") + chain: str + method: str + parameters: Optional[List[Any]] = [] + return_value_test: ReturnValueTest = Field(alias="returnValueTest") + input_array_type: Optional[List[str]] = Field(alias="inputArrayType") + output_type: Optional[str] = Field(alias="outputType") + + @validator('contract_address') + def validate_contract_address(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] != "": + if not v: + raise ValueError('contract_address is required when standardContractType is not empty') + return v + + @validator('method') + def validate_method(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] == "": + if v not in ["getBalance", "getBlockNumber"]: + raise ValueError('method must be getBalance or getBlockNumber when standardContractType is empty') + return v + + @validator('parameters') + def validate_parameters(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] != "": + if not v: + raise ValueError('parameters is required when standardContractType is not empty') + return v + + @validator('input_array_type') + def validate_input_array_type(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] == "Custom": + if not v: + raise ValueError('input_array_type is required when standardContractType is Custom') + for item in v: + if item not in SOLIDITY_TYPES: + raise ValueError(f'Invalid solidity type: {item}') + return v + + @validator('output_type') + def validate_output_type(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] == "Custom": + if not v: + raise ValueError('output_type is required when standardContractType is Custom') + if v not in SOLIDITY_TYPES: + raise ValueError(f'Invalid solidity type: {v}') + return v + + class Config: + allow_population_by_field_name = True + +class SolanaCondition(BaseModel): + id: int = Field(ge=1) + contract_address: Optional[str] = Field(alias="contractAddress") + chain: str + method: str + standard_contract_type: SolanaContractType = Field(alias="standardContractType") + parameters: Optional[List[Any]] = [] + pda_interface: PDAInterface = Field(alias="pdaInterface") + return_value_test: ReturnValueTest = Field(alias="returnValueTest") + + @validator('chain') + def validate_chain(cls, v): + if v.upper() not in SUPPORTED_CHAINS["SOLANA"]: + raise ValueError(f'Invalid Solana chain: {v}') + return v + + @validator('contract_address') + def validate_contract_address(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] != "": + if not v: + raise ValueError('contract_address is required when standardContractType is not empty') + return v + + @validator('method') + def validate_method(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] == "": + if v not in ["getBalance", "getLastBlockTime", "getBlockHeight"]: + raise ValueError('method must be getBalance, getLastBlockTime, or getBlockHeight when standardContractType is empty') + else: + if v not in ["getTokenAccountsByOwner"]: + raise ValueError('method must be getTokenAccountsByOwner when standardContractType is not empty') + return v + + class Config: + allow_population_by_field_name = True + +class CoreumCondition(BaseModel): + id: int = Field(ge=1) + contract_address: Optional[str] = Field(alias="contractAddress") + denom: Optional[str] + classid: Optional[str] + standard_contract_type: Optional[str] = Field(alias="standardContractType", default="") + chain: str + method: str + parameters: Optional[List[Any]] = [] + return_value_test: ReturnValueTest = Field(alias="returnValueTest") + + @validator('chain') + def validate_chain(cls, v): + if v not in SUPPORTED_CHAINS["COREUM"]: + raise ValueError(f'Invalid Coreum chain: {v}') + return v + + @validator('contract_address') + def validate_contract_address(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] != "": + if not v: + raise ValueError('contract_address is required when standardContractType is not empty') + return v + + @validator('parameters') + def validate_parameters(cls, v, values): + if 'standard_contract_type' in values and values['standard_contract_type'] != "": + if not v: + raise ValueError('parameters is required when standardContractType is not empty') + return v + + class Config: + allow_population_by_field_name = True + +class RadixCondition(BaseModel): + id: int = Field(ge=1) + standard_contract_type: Optional[str] = Field(alias="standardContractType", default="") + resource_address: str = Field(alias="resourceAddress") + chain: str + method: str + return_value_test: ReturnValueTest = Field(alias="returnValueTest") + + @validator('chain') + def validate_chain(cls, v): + if v not in SUPPORTED_CHAINS["RADIX"]: + raise ValueError(f'Invalid Radix chain: {v}') + return v + + class Config: + allow_population_by_field_name = True + +class UpdateConditionSchema(BaseModel): + chain_type: ChainType = Field(alias="chainType", default=ChainType.EVM) + conditions: List[Union[EVMCondition, SolanaCondition, CoreumCondition, RadixCondition]] + decryption_type: DecryptionType = Field(alias="decryptionType", default=DecryptionType.ADDRESS) + address: str + cid: str + aggregator: Optional[str] = None + + @validator('conditions') + def validate_conditions_uniqueness(cls, v): + ids = [condition.id for condition in v] + if len(ids) != len(set(ids)): + raise ValueError('Condition IDs must be unique') + return v + + @validator('aggregator') + def validate_aggregator(cls, v, values): + if 'conditions' in values and len(values['conditions']) > 1: + if not v: + raise ValueError('aggregator is required when there are multiple conditions') + if not re.search(r'( and | or )', v, re.IGNORECASE): + raise ValueError('aggregator must contain " and " or " or "') + return v + + @root_validator + def validate_condition_types(cls, values): + chain_type = values.get('chain_type', ChainType.EVM) + conditions = values.get('conditions', []) + + expected_type = { + ChainType.EVM: EVMCondition, + ChainType.SOLANA: SolanaCondition, + ChainType.COREUM: CoreumCondition, + ChainType.RADIX: RadixCondition + }.get(chain_type) + + for condition in conditions: + if not isinstance(condition, expected_type): + raise ValueError(f'All conditions must be of type {expected_type.__name__} for chain type {chain_type}') + + return values + + class Config: + allow_population_by_field_name = True + +class AccessConditionSchema(BaseModel): + chain_type: ChainType = Field(alias="chainType", default=ChainType.EVM) + decryption_type: DecryptionType = Field(alias="decryptionType", default=DecryptionType.ADDRESS) + conditions: List[Union[EVMCondition, SolanaCondition, CoreumCondition, RadixCondition]] + address: str + key_shards: List[dict] = Field(alias="keyShards", min_items=5, max_items=5) + cid: str + aggregator: Optional[str] = None + + @validator('conditions') + def validate_conditions_uniqueness(cls, v): + ids = [condition.id for condition in v] + if len(ids) != len(set(ids)): + raise ValueError('Condition IDs must be unique') + return v + + @validator('aggregator') + def validate_aggregator(cls, v, values): + if 'conditions' in values and len(values['conditions']) > 1: + if not v: + raise ValueError('aggregator is required when there are multiple conditions') + if not re.search(r'( and | or )', v, re.IGNORECASE): + raise ValueError('aggregator must contain " and " or " or "') + return v + + @root_validator + def validate_condition_types(cls, values): + chain_type = values.get('chain_type', ChainType.EVM) + conditions = values.get('conditions', []) + + expected_type = { + ChainType.EVM: EVMCondition, + ChainType.SOLANA: SolanaCondition, + ChainType.COREUM: CoreumCondition, + ChainType.RADIX: RadixCondition + }.get(chain_type) + + for condition in conditions: + if not isinstance(condition, expected_type): + raise ValueError(f'All conditions must be of type {expected_type.__name__} for chain type {chain_type}') + + return values + + class Config: + allow_population_by_field_name = True \ No newline at end of file diff --git a/src/lighthouseweb3/functions/kavach/util.py b/src/lighthouseweb3/functions/kavach/util.py new file mode 100644 index 0000000..9e2a2b4 --- /dev/null +++ b/src/lighthouseweb3/functions/kavach/util.py @@ -0,0 +1,103 @@ +import re +import json +import asyncio +import aiohttp +from typing import Any, Optional, Union +from urllib.parse import urljoin +from src.lighthouseweb3.functions.config import Config + +def is_cid_reg(cid: str) -> bool: + """Check if string is a valid CID (Content Identifier)""" + pattern = r'Qm[1-9A-HJ-NP-Za-km-z]{44}|b[A-Za-z2-7]{58}|B[A-Z2-7]{58}|z[1-9A-HJ-NP-Za-km-z]{48}|F[0-9A-F]{50}' + return bool(re.match(pattern, cid)) + +def is_equal(*objects: Any) -> bool: + """Check if all objects are equal by comparing their JSON representations""" + if not objects: + return True + + first_obj_json = json.dumps(objects[0], sort_keys=True) + return all(json.dumps(obj, sort_keys=True) == first_obj_json for obj in objects) + +async def api_node_handler( + endpoint: str, + verb: str, + auth_token: str = "", + body: Any = None, + retry_count: int = 3 +) -> Any: + """ + Handle API requests to node with retry logic + + Args: + endpoint: API endpoint path + verb: HTTP method (GET, POST, DELETE, PUT) + auth_token: Bearer token for authentication + body: Request body for POST/PUT/DELETE requests + retry_count: Number of retry attempts + + Returns: + JSON response from API + + Raises: + Exception: If request fails after all retries + """ + verb = verb.upper() + + + base_url = Config.lighthouse_bls_node_dev if Config.is_dev else Config.lighthouse_auth_node + url = urljoin(base_url, endpoint) + + + headers = { + "Content-Type": "application/json" + } + + if auth_token: + headers["Authorization"] = f"Bearer {auth_token}" + + + json_data = None + if verb in ["POST", "PUT", "DELETE"] and body is not None: + json_data = body + + + for i in range(retry_count): + try: + async with aiohttp.ClientSession() as session: + async with session.request( + method=verb, + url=url, + headers=headers, + json=json_data + ) as response: + + if not response.ok: + if response.status == 404: + raise Exception(json.dumps({ + "message": "fetch Error", + "statusCode": response.status + })) + + try: + error_body = await response.json() + except: + error_body = {"message": "Unknown error"} + + raise Exception(json.dumps({ + **error_body, + "statusCode": response.status + })) + + return await response.json() + + except Exception as error: + error_str = str(error) + if "fetch" not in error_str: + raise error + + if i == retry_count - 1: # Last attempt + raise error + + # Wait 1 second before retry + await asyncio.sleep(1) \ No newline at end of file From 2d0c6c8265aeca3ef71ab16c7f350cf1c1199285 Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Fri, 11 Jul 2025 16:13:15 +0000 Subject: [PATCH 14/18] feat:added types --- src/lighthouseweb3/functions/kavach/types.py | 157 +++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 src/lighthouseweb3/functions/kavach/types.py diff --git a/src/lighthouseweb3/functions/kavach/types.py b/src/lighthouseweb3/functions/kavach/types.py new file mode 100644 index 0000000..47f8843 --- /dev/null +++ b/src/lighthouseweb3/functions/kavach/types.py @@ -0,0 +1,157 @@ +from typing import List, Dict, Union, Optional, Any, Literal +from dataclasses import dataclass +from enum import Enum + +ErrorValue = Union[str, List[str], int, bool, None, Dict[str, Any], Any] +SignedMessage = str +JWT = str +AuthToken = Union[SignedMessage, JWT] + +# Enums +class ChainType(str, Enum): + EVM = "EVM" + EVM_LOWER = "evm" + SOLANA = "SOLANA" + SOLANA_LOWER = "solana" + +class DecryptionType(str, Enum): + ADDRESS = "ADDRESS" + ACCESS_CONDITIONS = "ACCESS_CONDITIONS" + +class StandardContractType(str, Enum): + ERC20 = "ERC20" + ERC721 = "ERC721" + ERC1155 = "ERC1155" + CUSTOM = "Custom" + EMPTY = "" + +class SolanaContractType(str, Enum): + SPL_TOKEN = "spl-token" + EMPTY = "" + +class Comparator(str, Enum): + EQUAL = "==" + GREATER_EQUAL = ">=" + LESS_EQUAL = "<=" + NOT_EQUAL = "!=" + GREATER = ">" + LESS = "<" + +# Data Classes +@dataclass +class KeyShard: + key: str + index: str + +@dataclass +class GeneratedKey: + master_key: Optional[str] + key_shards: List[KeyShard] + +@dataclass +class GenerateInput: + threshold: Optional[int] = None + key_count: Optional[int] = None + +@dataclass +class AuthMessage: + message: Optional[str] + error: Optional[ErrorValue] + +@dataclass +class RecoveredKey: + master_key: Optional[str] + error: Optional[ErrorValue] + +@dataclass +class RecoverShards: + shards: List[KeyShard] + error: ErrorValue + +@dataclass +class LightHouseSDKResponse: + is_success: bool + error: ErrorValue + +@dataclass +class ReturnValueTest: + comparator: Comparator + value: Union[int, str, List[Any]] + +@dataclass +class PDAInterface: + offset: Optional[int] = None + selector: Optional[str] = None + +@dataclass +class EVMCondition: + id: int + standard_contract_type: StandardContractType + chain: str + method: str + return_value_test: ReturnValueTest + contract_address: Optional[str] = None + parameters: Optional[List[Any]] = None + input_array_type: Optional[List[str]] = None + output_type: Optional[str] = None + +@dataclass +class SolanaCondition: + id: int + chain: str + method: str + standard_contract_type: SolanaContractType + pda_interface: PDAInterface + return_value_test: ReturnValueTest + contract_address: Optional[str] = None + parameters: Optional[List[Any]] = None + +# Union Type for Conditions +Condition = Union[EVMCondition, SolanaCondition] + +@dataclass +class UpdateConditionSchema: + chain_type: Literal["EVM", "SOLANA"] + conditions: List[Condition] + decryption_type: Literal["ADDRESS", "ACCESS_CONDITIONS"] + address: str + cid: str + aggregator: Optional[str] = None + +@dataclass +class AccessConditionSchema: + chain_type: Literal["EVM", "SOLANA"] + conditions: List[Condition] + decryption_type: Literal["ADDRESS", "ACCESS_CONDITIONS"] + address: str + cid: str + key_shards: List[Any] + aggregator: Optional[str] = None + +@dataclass +class IGetAccessCondition: + aggregator: str + owner: str + cid: str + conditions: Optional[List[Condition]] = None + conditions_solana: Optional[List[Any]] = None + shared_to: Optional[List[Any]] = None + +def is_jwt(token: str) -> bool: + """Check if token is a JWT (starts with 'jwt:')""" + return token.startswith('jwt:') + +def create_jwt(token: str) -> JWT: + """Create a JWT token with proper prefix""" + if not token.startswith('jwt:'): + return f'jwt:{token}' + return token + +# Type Guards +def is_evm_condition(condition: Condition) -> bool: + """Check if condition is an EVM condition""" + return isinstance(condition, EVMCondition) + +def is_solana_condition(condition: Condition) -> bool: + """Check if condition is a Solana condition""" + return isinstance(condition, SolanaCondition) \ No newline at end of file From 1bb63b71ed9a2a9f7d626ff1a06b8e1093bd2689 Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Fri, 11 Jul 2025 19:56:50 +0000 Subject: [PATCH 15/18] feat:added get_Auth_msg and access control setup --- requirements.txt | 2 +- src/lighthouseweb3/__init__.py | 29 ++- .../functions/kavach/access_control/main.py | 110 +++++++++++- .../kavach/access_control/validator.py | 6 +- .../functions/kavach/get_auth_message.py | 10 ++ src/lighthouseweb3/functions/kavach/types.py | 1 - src/lighthouseweb3/functions/kavach/util.py | 82 ++++----- tests/tests_kavach/test_access_control.py | 167 ++++++++++++++++++ tests/tests_kavach/test_get_auth_message.py | 31 ++++ 9 files changed, 384 insertions(+), 54 deletions(-) create mode 100644 src/lighthouseweb3/functions/kavach/get_auth_message.py create mode 100644 tests/tests_kavach/test_access_control.py create mode 100644 tests/tests_kavach/test_get_auth_message.py diff --git a/requirements.txt b/requirements.txt index 7b040aa..a5ad8b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ idna==3.4 requests==2.31.0 urllib3==2.0.2 eth-account==0.13.7 -aiohttp \ No newline at end of file +httpx==0.28.1 \ No newline at end of file diff --git a/src/lighthouseweb3/__init__.py b/src/lighthouseweb3/__init__.py index 6078bba..ef4a454 100644 --- a/src/lighthouseweb3/__init__.py +++ b/src/lighthouseweb3/__init__.py @@ -2,7 +2,7 @@ import os import io -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional from .functions import ( upload as d, deal_status, @@ -20,8 +20,12 @@ from .functions.kavach import ( generate, recover_key as recoverKey, - shard_key as shardKey + shard_key as shardKey, + get_auth_message as getAuthMessage, + types as kavach_types ) +from .functions.kavach.access_control import main as accessControl +from .functions.kavach.types import AuthToken, Condition, ChainType, DecryptionType, KeyShard class Lighthouse: def __init__(self, token: str = ""): @@ -251,4 +255,23 @@ def shardKey(masterKey: int, threshold: int, keyCount: int): return shardKey.shard_key(masterKey, threshold, keyCount) except Exception as e: raise e - \ No newline at end of file + + @staticmethod + def accessControl(address: str, cid: str, auth_token: AuthToken, conditions: List[Condition], aggregator: Optional[str] = None, chain_type: ChainType = "evm", key_shards: List[KeyShard] = [], decryption_type: DecryptionType = "ADDRESS"): + try: + return accessControl.access_control(address, cid, auth_token, conditions, aggregator, chain_type, key_shards, decryption_type) + except Exception as e: + raise e + + + @staticmethod + def getAuthMessage(address: str) -> dict[str, Any]: + """ + Get Authentication message from the server + :param address: str, The public key of the user + :return: dict, A dict with authentication message or error + """ + try: + return getAuthMessage.get_auth_message(address) + except Exception as e: + raise e \ No newline at end of file diff --git a/src/lighthouseweb3/functions/kavach/access_control/main.py b/src/lighthouseweb3/functions/kavach/access_control/main.py index b2ed25b..600c48c 100644 --- a/src/lighthouseweb3/functions/kavach/access_control/main.py +++ b/src/lighthouseweb3/functions/kavach/access_control/main.py @@ -1 +1,109 @@ -from .validator import updateConditionSchema, accessConditionSchema \ No newline at end of file +from .validator import UpdateConditionSchema as update_condition_schema, AccessConditionSchema as access_condition_schema +from ..types import (AuthToken, + Condition, + DecryptionType, + KeyShard, + ChainType, + LightHouseSDKResponse) + +from src.lighthouseweb3.functions.config import Config +from src.lighthouseweb3.functions.kavach.util import is_equal, is_cid_reg, api_node_handler +from typing import List, Optional + +async def access_control( + address: str, + cid: str, + auth_token: AuthToken, + conditions: List[Condition], + aggregator: Optional[str] = None, + chain_type: ChainType = "evm", + key_shards: List[KeyShard] = [], + decryption_type: DecryptionType = "ADDRESS" +) -> LightHouseSDKResponse: + try: + if not isinstance(key_shards, list) or ( + len(key_shards) != 5 and len(key_shards) != 0 + ): + raise ValueError("keyShards must be an array of 5 objects") + + if not is_cid_reg(cid): + raise ValueError("Invalid CID") + + try: + if len(key_shards) == 5: + access_condition_schema.parse_obj({ + "address": address, + "cid": cid, + "conditions": conditions, + "aggregator": aggregator, + "decryptionType": decryption_type, + "chainType": chain_type, + "keyShards": key_shards + }) + else: + update_condition_schema.parse_obj({ + "address": address, + "cid": cid, + "conditions": conditions, + "aggregator": aggregator, + "chainType": chain_type + }) + except ValueError as e: + raise ValueError(f"Condition validation error: {str(e)}") + + node_ids = [1, 2, 3, 4, 5] + node_urls = [ + f":900{id}/api/fileAccessConditions/{id}" if Config.is_dev else f"/api/fileAccessConditions/{id}" + for id in node_ids + ] + + data = [] + + for index, url in enumerate(node_urls): + try: + if len(key_shards) == 5: + response = await api_node_handler( + url, "POST", auth_token, { + "address": address, + "cid": cid, + "conditions": conditions, + "aggregator": aggregator, + "decryptionType": decryption_type, + "chainType": chain_type, + "payload": key_shards[index] + } + ) + else: + response = await api_node_handler( + url, "PUT", auth_token, { + "address": address, + "cid": cid, + "conditions": conditions, + "aggregator": aggregator, + "chainType": chain_type + } + ) + except Exception as e: + try: + error_data = json.loads(str(e)) + except Exception: + error_data = {"message": str(e)} + response = {"isSuccess": False, "error": error_data} + + if response.get("error"): + time.sleep(1) # Wait for 1 second before retrying + + data.append(response) + + success = ( + is_equal(*(resp.get("message") for resp in data)) and + data[0].get("message") == "success" + ) + + return {"isSuccess": success, "error": None} + + except Exception as e: + try: + return {"isSuccess": False, "error": json.loads(str(e))} + except Exception: + return {"isSuccess": False, "error": {"message": str(e)}} \ No newline at end of file diff --git a/src/lighthouseweb3/functions/kavach/access_control/validator.py b/src/lighthouseweb3/functions/kavach/access_control/validator.py index 5afdce7..693cb86 100644 --- a/src/lighthouseweb3/functions/kavach/access_control/validator.py +++ b/src/lighthouseweb3/functions/kavach/access_control/validator.py @@ -175,7 +175,7 @@ def validate_contract_address(cls, v, values): def validate_parameters(cls, v, values): if 'standard_contract_type' in values and values['standard_contract_type'] != "": if not v: - raise ValueError('parameters is required when standardContractType is not empty') + raise ValueError('parameters is required when standard_contract_type is not empty') return v class Config: @@ -222,7 +222,7 @@ def validate_aggregator(cls, v, values): raise ValueError('aggregator must contain " and " or " or "') return v - @root_validator + @root_validator(skip_on_failure=True) def validate_condition_types(cls, values): chain_type = values.get('chain_type', ChainType.EVM) conditions = values.get('conditions', []) @@ -268,7 +268,7 @@ def validate_aggregator(cls, v, values): raise ValueError('aggregator must contain " and " or " or "') return v - @root_validator + @root_validator(skip_on_failure=True) def validate_condition_types(cls, values): chain_type = values.get('chain_type', ChainType.EVM) conditions = values.get('conditions', []) diff --git a/src/lighthouseweb3/functions/kavach/get_auth_message.py b/src/lighthouseweb3/functions/kavach/get_auth_message.py new file mode 100644 index 0000000..eeb26a2 --- /dev/null +++ b/src/lighthouseweb3/functions/kavach/get_auth_message.py @@ -0,0 +1,10 @@ +from typing import Any +from .util import api_node_handler + + +async def get_auth_message(address: str) -> dict[str, Any]: + try: + response = await api_node_handler(f"/api/message/{address}", "GET") + return {'message': response[0]['message'], 'error': None} + except Exception as e: + return {'message': None, 'error':str(e)} \ No newline at end of file diff --git a/src/lighthouseweb3/functions/kavach/types.py b/src/lighthouseweb3/functions/kavach/types.py index 47f8843..3e13fd4 100644 --- a/src/lighthouseweb3/functions/kavach/types.py +++ b/src/lighthouseweb3/functions/kavach/types.py @@ -7,7 +7,6 @@ JWT = str AuthToken = Union[SignedMessage, JWT] -# Enums class ChainType(str, Enum): EVM = "EVM" EVM_LOWER = "evm" diff --git a/src/lighthouseweb3/functions/kavach/util.py b/src/lighthouseweb3/functions/kavach/util.py index 9e2a2b4..40f132f 100644 --- a/src/lighthouseweb3/functions/kavach/util.py +++ b/src/lighthouseweb3/functions/kavach/util.py @@ -1,9 +1,8 @@ import re import json import asyncio -import aiohttp +import httpx from typing import Any, Optional, Union -from urllib.parse import urljoin from src.lighthouseweb3.functions.config import Config def is_cid_reg(cid: str) -> bool: @@ -43,61 +42,54 @@ async def api_node_handler( Exception: If request fails after all retries """ verb = verb.upper() - - - base_url = Config.lighthouse_bls_node_dev if Config.is_dev else Config.lighthouse_auth_node - url = urljoin(base_url, endpoint) - - + url = Config.lighthouse_bls_node if not Config.is_dev else Config.lighthouse_bls_node_dev + url += endpoint + headers = { "Content-Type": "application/json" } if auth_token: headers["Authorization"] = f"Bearer {auth_token}" - - - json_data = None - if verb in ["POST", "PUT", "DELETE"] and body is not None: - json_data = body - - - for i in range(retry_count): - try: - async with aiohttp.ClientSession() as session: - async with session.request( + + json_data = body if verb in ["POST", "PUT", "DELETE"] and body is not None else None + + async with httpx.AsyncClient() as client: + for i in range(retry_count): + try: + response = await client.request( method=verb, url=url, headers=headers, json=json_data - ) as response: - - if not response.ok: - if response.status == 404: - raise Exception(json.dumps({ - "message": "fetch Error", - "statusCode": response.status - })) - - try: - error_body = await response.json() - except: - error_body = {"message": "Unknown error"} - + ) + + if not response.is_success: + if response.status_code == 404: raise Exception(json.dumps({ - **error_body, - "statusCode": response.status + "message": "fetch Error", + "statusCode": response.status_code })) - return await response.json() + try: + error_body = response.json() + except: + error_body = {"message": "Unknown error"} - except Exception as error: - error_str = str(error) - if "fetch" not in error_str: - raise error - - if i == retry_count - 1: # Last attempt - raise error + raise Exception(json.dumps({ + **error_body, + "statusCode": response.status_code + })) - # Wait 1 second before retry - await asyncio.sleep(1) \ No newline at end of file + return response.json() + + except Exception as error: + error_str = str(error) + if "fetch" not in error_str: + raise error + + if i == retry_count - 1: # Last attempt + raise error + + # Wait 1 second before retry + await asyncio.sleep(1) \ No newline at end of file diff --git a/tests/tests_kavach/test_access_control.py b/tests/tests_kavach/test_access_control.py new file mode 100644 index 0000000..932725c --- /dev/null +++ b/tests/tests_kavach/test_access_control.py @@ -0,0 +1,167 @@ +import unittest +import asyncio +from eth_account import Account +from eth_account.messages import encode_defunct +from src.lighthouseweb3 import Kavach + + +class TestAccessControl(unittest.IsolatedAsyncioTestCase): + """Test class for AccessControl functionality""" + + async def asyncSetUp(self): + """Setup test environment""" + # Create signer with the same private key as in JS test + self.private_key = "0x8218aa5dbf4dbec243142286b93e26af521b3e91219583595a06a7765abc9c8b" + self.signer = Account.from_key(self.private_key) + self.address = self.signer.address + + async def test_invalid_condition(self): + """Test invalid condition validation""" + result = await Kavach.accessControl( + address=self.address, + cid="QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH", + auth_token="swrwwr", + conditions=[ + { + "id": 1, + "chain": "FantomTes", # Invalid chain name + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + }, + { + "id": 1, # Duplicate ID + "chain": "FantomTest", + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + }, + ], + aggregator="([2] and [1])" + ) + + self.assertIsInstance(result["error"], str) + self.assertIn("Condition validation error:", result["error"]) + + # async def test_invalid_signature(self): + # """Test invalid signature handling""" + # result = await Kavach.accessControl( + # address=self.address, + # cid="QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH", + # auth_token="swrwwr", # Invalid signature + # conditions=[ + # { + # "id": 1, + # "chain": "FantomTest", + # "method": "balanceOf", + # "standardContractType": "ERC20", + # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + # "returnValueTest": {"comparator": ">=", "value": "0"}, + # "parameters": [":userAddress"], + # }, + # { + # "id": 2, + # "chain": "FantomTest", + # "method": "balanceOf", + # "standardContractType": "ERC20", + # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + # "returnValueTest": {"comparator": ">=", "value": "0"}, + # "parameters": [":userAddress"], + # }, + # ], + # aggregator="([2] and [1])" + # ) + + # self.assertIsInstance(result["error"], dict) + # self.assertIn("invalid signature", result["error"]["message"].lower()) + + # async def test_data_conditions(self): + # """Test valid data conditions""" + # # Get auth message and sign it + # auth_message = await Kavach.getAuthMessage(address=self.address) + # message_to_sign = encode_defunct(text=auth_message["message"]) + # signed_message = self.signer.sign_message(message_to_sign) + + # result = await Kavach.accessControl( + # address=self.address, + # cid="QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH", + # auth_token=signed_message.signature.hex(), + # conditions=[ + # { + # "id": 1, + # "chain": "FantomTest", + # "method": "balanceOf", + # "standardContractType": "ERC20", + # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + # "returnValueTest": {"comparator": ">=", "value": "0"}, + # "parameters": [":userAddress"], + # }, + # { + # "id": 2, + # "chain": "FantomTest", + # "method": "balanceOf", + # "standardContractType": "ERC20", + # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + # "returnValueTest": {"comparator": ">=", "value": "0"}, + # "parameters": [":userAddress"], + # }, + # ], + # aggregator="([2] and [1])" + # ) + + # self.assertIsNone(result["error"]) + # self.assertTrue(result["isSuccess"]) + + # async def test_add_new_cid_access_conditions(self): + # """Test adding new CID access conditions with key shards""" + # # Get auth message and sign it + # auth_message = await Kavach.getAuthMessage(address=self.address) + # message_to_sign = encode_defunct(text=auth_message["message"]) + # signed_message = self.signer.sign_message(message_to_sign) + + # # Generate key shards (adjust according to your actual generate function) + # key_generation_result = await Kavach.generate(3, 5) + # master_key = key_generation_result["masterKey"] + # key_shards = key_generation_result["keyShards"] + + # result = await Kavach.accessControl( + # address=self.address, + # cid="QmPzhJDbMgoxXH7JoRc1roXqkLGtngLiGVhegiDEmmTnbM", + # auth_token=signed_message.signature.hex(), + # conditions=[ + # { + # "id": 1, + # "chain": "FantomTest", + # "method": "balanceOf", + # "standardContractType": "ERC20", + # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + # "returnValueTest": {"comparator": ">=", "value": "0"}, + # "parameters": [":userAddress"], + # }, + # { + # "id": 2, + # "chain": "FantomTest", + # "method": "balanceOf", + # "standardContractType": "ERC20", + # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + # "returnValueTest": {"comparator": ">=", "value": "0"}, + # "parameters": [":userAddress"], + # }, + # ], + # aggregator="([2] and [1])", + # chain_type="EVM", + # key_shards=key_shards, + # decryption_type="ADDRESS" + # ) + + # self.assertIsNone(result["error"]) + # self.assertTrue(result["isSuccess"]) + + +# Additional test runner for individual test execution +if __name__ == "__main__": + unittest.main(verbosity=2) \ No newline at end of file diff --git a/tests/tests_kavach/test_get_auth_message.py b/tests/tests_kavach/test_get_auth_message.py new file mode 100644 index 0000000..1c043fe --- /dev/null +++ b/tests/tests_kavach/test_get_auth_message.py @@ -0,0 +1,31 @@ +import unittest +import logging +from src.lighthouseweb3 import Kavach + +logger = logging.getLogger(__name__) + +class TestGetAuthMessage(unittest.IsolatedAsyncioTestCase): + """Test cases for the getAuthMessage function.""" + + async def test_get_auth_message_valid_address(self): + """Test getting auth message with a valid address.""" + address = 'h6gar47c9GxYda8Kkg5J9So3R9K3jhcWKbgrjKhqfst' + auth_message = await Kavach.getAuthMessage(address=address) + + self.assertIn( + "Please sign this message to prove you are owner of this account", + auth_message['message'], + "Should return a valid auth message" + ) + self.assertIsNone(auth_message['error']) + + async def test_get_auth_message_invalid_address(self): + """Test getting auth message with an invalid address.""" + auth_message = await Kavach.getAuthMessage(address="0x9a40b8EE3B8Fe7eB621cd142a651560Fa7") + + self.assertIsNone(auth_message['message']) + self.assertIsNotNone(auth_message['error']) + self.assertIn("invalid address", str(auth_message["error"]).lower()) + +if __name__ == '__main__': + unittest.main(verbosity=2) \ No newline at end of file From d23eef87c19558f7c2fd94aa0fd743ee681ccc30 Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Fri, 11 Jul 2025 21:16:16 +0000 Subject: [PATCH 16/18] feat:completed accesscontrol method --- .../functions/kavach/access_control/main.py | 221 +++---- .../kavach/access_control/validator.py | 554 ++++++++++-------- tests/tests_kavach/test_access_control.py | 335 ++++++----- 3 files changed, 597 insertions(+), 513 deletions(-) diff --git a/src/lighthouseweb3/functions/kavach/access_control/main.py b/src/lighthouseweb3/functions/kavach/access_control/main.py index 600c48c..8a89233 100644 --- a/src/lighthouseweb3/functions/kavach/access_control/main.py +++ b/src/lighthouseweb3/functions/kavach/access_control/main.py @@ -1,109 +1,130 @@ from .validator import UpdateConditionSchema as update_condition_schema, AccessConditionSchema as access_condition_schema from ..types import (AuthToken, - Condition, - DecryptionType, - KeyShard, - ChainType, - LightHouseSDKResponse) + Condition, + DecryptionType, + KeyShard, + ChainType, + LightHouseSDKResponse) + from src.lighthouseweb3.functions.config import Config from src.lighthouseweb3.functions.kavach.util import is_equal, is_cid_reg, api_node_handler from typing import List, Optional +import json +import time + async def access_control( - address: str, - cid: str, - auth_token: AuthToken, - conditions: List[Condition], - aggregator: Optional[str] = None, - chain_type: ChainType = "evm", - key_shards: List[KeyShard] = [], - decryption_type: DecryptionType = "ADDRESS" + address: str, + cid: str, + auth_token: AuthToken, + conditions: List[Condition], + aggregator: Optional[str] = None, + chain_type: ChainType = "evm", + key_shards: List[KeyShard] = [], + decryption_type: DecryptionType = "ADDRESS" ) -> LightHouseSDKResponse: - try: - if not isinstance(key_shards, list) or ( - len(key_shards) != 5 and len(key_shards) != 0 - ): - raise ValueError("keyShards must be an array of 5 objects") - - if not is_cid_reg(cid): - raise ValueError("Invalid CID") - - try: - if len(key_shards) == 5: - access_condition_schema.parse_obj({ - "address": address, - "cid": cid, - "conditions": conditions, - "aggregator": aggregator, - "decryptionType": decryption_type, - "chainType": chain_type, - "keyShards": key_shards - }) - else: - update_condition_schema.parse_obj({ - "address": address, - "cid": cid, - "conditions": conditions, - "aggregator": aggregator, - "chainType": chain_type - }) - except ValueError as e: - raise ValueError(f"Condition validation error: {str(e)}") - - node_ids = [1, 2, 3, 4, 5] - node_urls = [ - f":900{id}/api/fileAccessConditions/{id}" if Config.is_dev else f"/api/fileAccessConditions/{id}" - for id in node_ids - ] - - data = [] - - for index, url in enumerate(node_urls): - try: - if len(key_shards) == 5: - response = await api_node_handler( - url, "POST", auth_token, { - "address": address, - "cid": cid, - "conditions": conditions, - "aggregator": aggregator, - "decryptionType": decryption_type, - "chainType": chain_type, - "payload": key_shards[index] - } - ) - else: - response = await api_node_handler( - url, "PUT", auth_token, { - "address": address, - "cid": cid, - "conditions": conditions, - "aggregator": aggregator, - "chainType": chain_type - } - ) - except Exception as e: - try: - error_data = json.loads(str(e)) - except Exception: - error_data = {"message": str(e)} - response = {"isSuccess": False, "error": error_data} - - if response.get("error"): - time.sleep(1) # Wait for 1 second before retrying - - data.append(response) - - success = ( - is_equal(*(resp.get("message") for resp in data)) and - data[0].get("message") == "success" - ) - - return {"isSuccess": success, "error": None} - - except Exception as e: - try: - return {"isSuccess": False, "error": json.loads(str(e))} - except Exception: - return {"isSuccess": False, "error": {"message": str(e)}} \ No newline at end of file + try: + if not isinstance(key_shards, list) or ( + len(key_shards) != 5 and len(key_shards) != 0 + ): + raise ValueError("keyShards must be an array of 5 objects") + + + if not is_cid_reg(cid): + raise ValueError("Invalid CID") + + + try: + if len(key_shards) == 5: + access_condition_schema.model_validate({ + "address": address, + "cid": cid, + "conditions": conditions, + "aggregator": aggregator, + "decryptionType": decryption_type, + "chainType": chain_type, + "keyShards": key_shards + }) + else: + update_condition_schema.model_validate({ + "address": address, + "cid": cid, + "conditions": conditions, + "aggregator": aggregator, + "chainType": chain_type + }) + except ValueError as e: + raise ValueError(f"Condition validation error: {str(e)}") + + + node_ids = [1, 2, 3, 4, 5] + node_urls = [ + f":900{id}/api/fileAccessConditions/{id}" if Config.is_dev else f"/api/fileAccessConditions/{id}" + for id in node_ids + ] + + + data = [] + + + for index, url in enumerate(node_urls): + try: + if len(key_shards) == 5: + response = await api_node_handler( + url, "POST", auth_token, { + "address": address, + "cid": cid, + "conditions": conditions, + "aggregator": aggregator, + "decryptionType": decryption_type, + "chainType": chain_type, + "payload": key_shards[index] + } + ) + else: + response = await api_node_handler( + url, "PUT", auth_token, { + "address": address, + "cid": cid, + "conditions": conditions, + "aggregator": aggregator, + "chainType": chain_type + } + ) + except Exception as e: + try: + error_data = json.loads(str(e)) + except Exception: + error_data = {"message": str(e)} + response = {"isSuccess": False, "error": error_data} + + + if response.get("error"): + time.sleep(1) + + + data.append(response) + + + success = ( + is_equal(*(resp.get("message") for resp in data)) and + data[0].get("message") == "success" + ) + + + return {"isSuccess": success, "error": None} + + + + except Exception as e: + if isinstance(e, ValueError): + return {"isSuccess": False, "error": str(e)} + try: + err = json.loads(str(e)) + if isinstance(err, dict) and "message" in err: + return {"isSuccess": False, "error": err["message"]} + return {"isSuccess": False, "error": str(e)} + except Exception: + return {"isSuccess": False, "error": str(e)} + diff --git a/src/lighthouseweb3/functions/kavach/access_control/validator.py b/src/lighthouseweb3/functions/kavach/access_control/validator.py index 693cb86..e5f6a48 100644 --- a/src/lighthouseweb3/functions/kavach/access_control/validator.py +++ b/src/lighthouseweb3/functions/kavach/access_control/validator.py @@ -1,290 +1,334 @@ from typing import List, Optional, Union, Any -from pydantic import BaseModel, Field, validator, root_validator +from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict from enum import Enum import re + SOLIDITY_TYPES = [ - "address", "address[]", "bool", "bool[]", - "bytes1", "bytes2", "bytes3", "bytes4", "bytes5", "bytes6", "bytes7", "bytes8", "bytes16", "bytes32", - "bytes1[]", "bytes2[]", "bytes3[]", "bytes4[]", "bytes5[]", "bytes6[]", "bytes7[]", "bytes8[]", "bytes16[]", "bytes32[]", - "uint8", "uint16", "uint24", "uint32", "uint40", "uint48", "uint64", "uint128", "uint192", "uint256", - "int8", "int16", "int24", "int32", "int40", "int48", "int64", "int128", "int192", "int256", - "uint8[]", "uint16[]", "uint24[]", "uint32[]", "uint40[]", "uint48[]", "uint64[]", "uint128[]", "uint192[]", "uint256[]", - "int8[]", "int16[]", "int24[]", "int32[]", "int40[]", "int48[]", "int64[]", "int128[]", "int192[]", "int256[]" + "address", "address[]", "bool", "bool[]", + "bytes1", "bytes2", "bytes3", "bytes4", "bytes5", "bytes6", "bytes7", "bytes8", "bytes16", "bytes32", + "bytes1[]", "bytes2[]", "bytes3[]", "bytes4[]", "bytes5[]", "bytes6[]", "bytes7[]", "bytes8[]", "bytes16[]", "bytes32[]", + "uint8", "uint16", "uint24", "uint32", "uint40", "uint48", "uint64", "uint128", "uint192", "uint256", + "int8", "int16", "int24", "int32", "int40", "int48", "int64", "int128", "int192", "int256", + "uint8[]", "uint16[]", "uint24[]", "uint32[]", "uint40[]", "uint48[]", "uint64[]", "uint128[]", "uint192[]", "uint256[]", + "int8[]", "int16[]", "int24[]", "int32[]", "int40[]", "int48[]", "int64[]", "int128[]", "int192[]", "int256[]" ] + SUPPORTED_CHAINS = { - "EVM": [], - "SOLANA": ["DEVNET", "TESTNET", "MAINNET"], - "COREUM": ["Coreum_Devnet", "Coreum_Testnet", "Coreum_Mainnet"], - "RADIX": ["Radix_Mainnet"] + "EVM": [], + "SOLANA": ["DEVNET", "TESTNET", "MAINNET"], + "COREUM": ["Coreum_Devnet", "Coreum_Testnet", "Coreum_Mainnet"], + "RADIX": ["Radix_Mainnet"] } + class ChainType(str, Enum): - EVM = "EVM" - SOLANA = "SOLANA" - COREUM = "COREUM" - RADIX = "RADIX" + EVM = "EVM" + SOLANA = "SOLANA" + COREUM = "COREUM" + RADIX = "RADIX" + class DecryptionType(str, Enum): - ADDRESS = "ADDRESS" - ACCESS_CONDITIONS = "ACCESS_CONDITIONS" + ADDRESS = "ADDRESS" + ACCESS_CONDITIONS = "ACCESS_CONDITIONS" + class StandardContractType(str, Enum): - ERC20 = "ERC20" - ERC721 = "ERC721" - ERC1155 = "ERC1155" - CUSTOM = "Custom" - EMPTY = "" + ERC20 = "ERC20" + ERC721 = "ERC721" + ERC1155 = "ERC1155" + CUSTOM = "Custom" + EMPTY = "" + class SolanaContractType(str, Enum): - SPL_TOKEN = "spl-token" - EMPTY = "" + SPL_TOKEN = "spl-token" + EMPTY = "" + class Comparator(str, Enum): - EQUAL = "==" - GREATER_EQUAL = ">=" - LESS_EQUAL = "<=" - NOT_EQUAL = "!=" - GREATER = ">" - LESS = "<" + EQUAL = "==" + GREATER_EQUAL = ">=" + LESS_EQUAL = "<=" + NOT_EQUAL = "!=" + GREATER = ">" + LESS = "<" + class ReturnValueTest(BaseModel): - comparator: Comparator - value: Union[int, float, str, List[Any]] + comparator: Comparator + value: Union[int, float, str, List[Any]] + class PDAInterface(BaseModel): - offset: int = Field(ge=0) - selector: str + offset: int = Field(ge=0) + selector: str + class EVMCondition(BaseModel): - id: int = Field(ge=1) - standard_contract_type: StandardContractType = Field(alias="standardContractType") - contract_address: Optional[str] = Field(alias="contractAddress") - chain: str - method: str - parameters: Optional[List[Any]] = [] - return_value_test: ReturnValueTest = Field(alias="returnValueTest") - input_array_type: Optional[List[str]] = Field(alias="inputArrayType") - output_type: Optional[str] = Field(alias="outputType") - - @validator('contract_address') - def validate_contract_address(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] != "": - if not v: - raise ValueError('contract_address is required when standardContractType is not empty') - return v - - @validator('method') - def validate_method(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] == "": - if v not in ["getBalance", "getBlockNumber"]: - raise ValueError('method must be getBalance or getBlockNumber when standardContractType is empty') - return v - - @validator('parameters') - def validate_parameters(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] != "": - if not v: - raise ValueError('parameters is required when standardContractType is not empty') - return v - - @validator('input_array_type') - def validate_input_array_type(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] == "Custom": - if not v: - raise ValueError('input_array_type is required when standardContractType is Custom') - for item in v: - if item not in SOLIDITY_TYPES: - raise ValueError(f'Invalid solidity type: {item}') - return v - - @validator('output_type') - def validate_output_type(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] == "Custom": - if not v: - raise ValueError('output_type is required when standardContractType is Custom') - if v not in SOLIDITY_TYPES: - raise ValueError(f'Invalid solidity type: {v}') - return v - - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(validate_by_name=True) + + id: int = Field(ge=1) + standard_contract_type: StandardContractType = Field(alias="standardContractType") + contract_address: Optional[str] = Field(alias="contractAddress") + chain: str + method: str + parameters: Optional[List[Any]] = [] + return_value_test: ReturnValueTest = Field(alias="returnValueTest") + input_array_type: Optional[List[str]] = Field(alias="inputArrayType") + output_type: Optional[str] = Field(alias="outputType") + + + @field_validator('contract_address') + @classmethod + def validate_contract_address(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] != "": + if not v: + raise ValueError('contract_address is required when standardContractType is not empty') + return v + + + @field_validator('method') + @classmethod + def validate_method(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] == "": + if v not in ["getBalance", "getBlockNumber"]: + raise ValueError('method must be getBalance or getBlockNumber when standardContractType is empty') + return v + + + @field_validator('parameters') + @classmethod + def validate_parameters(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] != "": + if not v: + raise ValueError('parameters is required when standardContractType is not empty') + return v + + + @field_validator('input_array_type') + @classmethod + def validate_input_array_type(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] == "Custom": + if not v: + raise ValueError('input_array_type is required when standardContractType is Custom') + for item in v: + if item not in SOLIDITY_TYPES: + raise ValueError(f'Invalid solidity type: {item}') + return v + + + @field_validator('output_type') + @classmethod + def validate_output_type(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] == "Custom": + if not v: + raise ValueError('output_type is required when standardContractType is Custom') + if v not in SOLIDITY_TYPES: + raise ValueError(f'Invalid solidity type: {v}') + return v + class SolanaCondition(BaseModel): - id: int = Field(ge=1) - contract_address: Optional[str] = Field(alias="contractAddress") - chain: str - method: str - standard_contract_type: SolanaContractType = Field(alias="standardContractType") - parameters: Optional[List[Any]] = [] - pda_interface: PDAInterface = Field(alias="pdaInterface") - return_value_test: ReturnValueTest = Field(alias="returnValueTest") - - @validator('chain') - def validate_chain(cls, v): - if v.upper() not in SUPPORTED_CHAINS["SOLANA"]: - raise ValueError(f'Invalid Solana chain: {v}') - return v - - @validator('contract_address') - def validate_contract_address(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] != "": - if not v: - raise ValueError('contract_address is required when standardContractType is not empty') - return v - - @validator('method') - def validate_method(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] == "": - if v not in ["getBalance", "getLastBlockTime", "getBlockHeight"]: - raise ValueError('method must be getBalance, getLastBlockTime, or getBlockHeight when standardContractType is empty') - else: - if v not in ["getTokenAccountsByOwner"]: - raise ValueError('method must be getTokenAccountsByOwner when standardContractType is not empty') - return v - - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(validate_by_name=True) + + id: int = Field(ge=1) + contract_address: Optional[str] = Field(alias="contractAddress") + chain: str + method: str + standard_contract_type: SolanaContractType = Field(alias="standardContractType") + parameters: Optional[List[Any]] = [] + pda_interface: PDAInterface = Field(alias="pdaInterface") + return_value_test: ReturnValueTest = Field(alias="returnValueTest") + + + @field_validator('chain') + @classmethod + def validate_chain(cls, v): + if v.upper() not in SUPPORTED_CHAINS["SOLANA"]: + raise ValueError(f'Invalid Solana chain: {v}') + return v + + + @field_validator('contract_address') + @classmethod + def validate_contract_address(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] != "": + if not v: + raise ValueError('contract_address is required when standardContractType is not empty') + return v + + + @field_validator('method') + @classmethod + def validate_method(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] == "": + if v not in ["getBalance", "getLastBlockTime", "getBlockHeight"]: + raise ValueError('method must be getBalance, getLastBlockTime, or getBlockHeight when standardContractType is empty') + else: + if v not in ["getTokenAccountsByOwner"]: + raise ValueError('method must be getTokenAccountsByOwner when standardContractType is not empty') + return v + class CoreumCondition(BaseModel): - id: int = Field(ge=1) - contract_address: Optional[str] = Field(alias="contractAddress") - denom: Optional[str] - classid: Optional[str] - standard_contract_type: Optional[str] = Field(alias="standardContractType", default="") - chain: str - method: str - parameters: Optional[List[Any]] = [] - return_value_test: ReturnValueTest = Field(alias="returnValueTest") - - @validator('chain') - def validate_chain(cls, v): - if v not in SUPPORTED_CHAINS["COREUM"]: - raise ValueError(f'Invalid Coreum chain: {v}') - return v - - @validator('contract_address') - def validate_contract_address(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] != "": - if not v: - raise ValueError('contract_address is required when standardContractType is not empty') - return v - - @validator('parameters') - def validate_parameters(cls, v, values): - if 'standard_contract_type' in values and values['standard_contract_type'] != "": - if not v: - raise ValueError('parameters is required when standard_contract_type is not empty') - return v - - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(validate_by_name=True) + + id: int = Field(ge=1) + contract_address: Optional[str] = Field(alias="contractAddress") + denom: Optional[str] + classid: Optional[str] + standard_contract_type: Optional[str] = Field(alias="standardContractType", default="") + chain: str + method: str + parameters: Optional[List[Any]] = [] + return_value_test: ReturnValueTest = Field(alias="returnValueTest") + + + @field_validator('chain') + @classmethod + def validate_chain(cls, v): + if v not in SUPPORTED_CHAINS["COREUM"]: + raise ValueError(f'Invalid Coreum chain: {v}') + return v + + + @field_validator('contract_address') + @classmethod + def validate_contract_address(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] != "": + if not v: + raise ValueError('contract_address is required when standardContractType is not empty') + return v + + + @field_validator('parameters') + @classmethod + def validate_parameters(cls, v, info): + if 'standard_contract_type' in info.data and info.data['standard_contract_type'] != "": + if not v: + raise ValueError('parameters is required when standard_contract_type is not empty') + return v + class RadixCondition(BaseModel): - id: int = Field(ge=1) - standard_contract_type: Optional[str] = Field(alias="standardContractType", default="") - resource_address: str = Field(alias="resourceAddress") - chain: str - method: str - return_value_test: ReturnValueTest = Field(alias="returnValueTest") - - @validator('chain') - def validate_chain(cls, v): - if v not in SUPPORTED_CHAINS["RADIX"]: - raise ValueError(f'Invalid Radix chain: {v}') - return v - - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(validate_by_name=True) + + id: int = Field(ge=1) + standard_contract_type: Optional[str] = Field(alias="standardContractType", default="") + resource_address: str = Field(alias="resourceAddress") + chain: str + method: str + return_value_test: ReturnValueTest = Field(alias="returnValueTest") + + + @field_validator('chain') + @classmethod + def validate_chain(cls, v): + if v not in SUPPORTED_CHAINS["RADIX"]: + raise ValueError(f'Invalid Radix chain: {v}') + return v + class UpdateConditionSchema(BaseModel): - chain_type: ChainType = Field(alias="chainType", default=ChainType.EVM) - conditions: List[Union[EVMCondition, SolanaCondition, CoreumCondition, RadixCondition]] - decryption_type: DecryptionType = Field(alias="decryptionType", default=DecryptionType.ADDRESS) - address: str - cid: str - aggregator: Optional[str] = None - - @validator('conditions') - def validate_conditions_uniqueness(cls, v): - ids = [condition.id for condition in v] - if len(ids) != len(set(ids)): - raise ValueError('Condition IDs must be unique') - return v - - @validator('aggregator') - def validate_aggregator(cls, v, values): - if 'conditions' in values and len(values['conditions']) > 1: - if not v: - raise ValueError('aggregator is required when there are multiple conditions') - if not re.search(r'( and | or )', v, re.IGNORECASE): - raise ValueError('aggregator must contain " and " or " or "') - return v - - @root_validator(skip_on_failure=True) - def validate_condition_types(cls, values): - chain_type = values.get('chain_type', ChainType.EVM) - conditions = values.get('conditions', []) - - expected_type = { - ChainType.EVM: EVMCondition, - ChainType.SOLANA: SolanaCondition, - ChainType.COREUM: CoreumCondition, - ChainType.RADIX: RadixCondition - }.get(chain_type) - - for condition in conditions: - if not isinstance(condition, expected_type): - raise ValueError(f'All conditions must be of type {expected_type.__name__} for chain type {chain_type}') - - return values - - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(validate_by_name=True) + + chain_type: ChainType = Field(alias="chainType", default=ChainType.EVM) + conditions: List[Union[EVMCondition, SolanaCondition, CoreumCondition, RadixCondition]] + decryption_type: DecryptionType = Field(alias="decryptionType", default=DecryptionType.ADDRESS) + address: str + cid: str + aggregator: Optional[str] = None + + + @field_validator('conditions') + @classmethod + def validate_conditions_uniqueness(cls, v): + ids = [condition.id for condition in v] + if len(ids) != len(set(ids)): + raise ValueError('Condition IDs must be unique') + return v + + + @field_validator('aggregator') + @classmethod + def validate_aggregator(cls, v, info): + if 'conditions' in info.data and len(info.data['conditions']) > 1: + if not v: + raise ValueError('aggregator is required when there are multiple conditions') + if not re.search(r'( and | or )', v, re.IGNORECASE): + raise ValueError('aggregator must contain " and " or " or "') + return v + + + @model_validator(mode='after') + def validate_condition_types(self): + chain_type = self.chain_type + conditions = self.conditions + + expected_type = { + ChainType.EVM: EVMCondition, + ChainType.SOLANA: SolanaCondition, + ChainType.COREUM: CoreumCondition, + ChainType.RADIX: RadixCondition + }.get(chain_type) + + for condition in conditions: + if not isinstance(condition, expected_type): + raise ValueError(f'All conditions must be of type {expected_type.__name__} for chain type {chain_type}') + + return self + class AccessConditionSchema(BaseModel): - chain_type: ChainType = Field(alias="chainType", default=ChainType.EVM) - decryption_type: DecryptionType = Field(alias="decryptionType", default=DecryptionType.ADDRESS) - conditions: List[Union[EVMCondition, SolanaCondition, CoreumCondition, RadixCondition]] - address: str - key_shards: List[dict] = Field(alias="keyShards", min_items=5, max_items=5) - cid: str - aggregator: Optional[str] = None - - @validator('conditions') - def validate_conditions_uniqueness(cls, v): - ids = [condition.id for condition in v] - if len(ids) != len(set(ids)): - raise ValueError('Condition IDs must be unique') - return v - - @validator('aggregator') - def validate_aggregator(cls, v, values): - if 'conditions' in values and len(values['conditions']) > 1: - if not v: - raise ValueError('aggregator is required when there are multiple conditions') - if not re.search(r'( and | or )', v, re.IGNORECASE): - raise ValueError('aggregator must contain " and " or " or "') - return v - - @root_validator(skip_on_failure=True) - def validate_condition_types(cls, values): - chain_type = values.get('chain_type', ChainType.EVM) - conditions = values.get('conditions', []) - - expected_type = { - ChainType.EVM: EVMCondition, - ChainType.SOLANA: SolanaCondition, - ChainType.COREUM: CoreumCondition, - ChainType.RADIX: RadixCondition - }.get(chain_type) - - for condition in conditions: - if not isinstance(condition, expected_type): - raise ValueError(f'All conditions must be of type {expected_type.__name__} for chain type {chain_type}') - - return values - - class Config: - allow_population_by_field_name = True \ No newline at end of file + model_config = ConfigDict(validate_by_name=True) + + chain_type: ChainType = Field(alias="chainType", default=ChainType.EVM) + decryption_type: DecryptionType = Field(alias="decryptionType", default=DecryptionType.ADDRESS) + conditions: List[Union[EVMCondition, SolanaCondition, CoreumCondition, RadixCondition]] + address: str + key_shards: List[dict] = Field(alias="keyShards", min_length=5, max_length=5) + cid: str + aggregator: Optional[str] = None + + + @field_validator('conditions') + @classmethod + def validate_conditions_uniqueness(cls, v): + ids = [condition.id for condition in v] + if len(ids) != len(set(ids)): + raise ValueError('Condition IDs must be unique') + return v + + + @field_validator('aggregator') + @classmethod + def validate_aggregator(cls, v, info): + if 'conditions' in info.data and len(info.data['conditions']) > 1: + if not v: + raise ValueError('aggregator is required when there are multiple conditions') + if not re.search(r'( and | or )', v, re.IGNORECASE): + raise ValueError('aggregator must contain " and " or " or "') + return v + + + @model_validator(mode='after') + def validate_condition_types(self): + chain_type = self.chain_type + conditions = self.conditions + + expected_type = { + ChainType.EVM: EVMCondition, + ChainType.SOLANA: SolanaCondition, + ChainType.COREUM: CoreumCondition, + ChainType.RADIX: RadixCondition + }.get(chain_type) + + for condition in conditions: + if not isinstance(condition, expected_type): + raise ValueError(f'All conditions must be of type {expected_type.__name__} for chain type {chain_type}') + + return self + diff --git a/tests/tests_kavach/test_access_control.py b/tests/tests_kavach/test_access_control.py index 932725c..bd4ec39 100644 --- a/tests/tests_kavach/test_access_control.py +++ b/tests/tests_kavach/test_access_control.py @@ -1,167 +1,186 @@ import unittest import asyncio -from eth_account import Account -from eth_account.messages import encode_defunct +import logging from src.lighthouseweb3 import Kavach +from eth_account.messages import encode_defunct +from web3 import Web3 -class TestAccessControl(unittest.IsolatedAsyncioTestCase): - """Test class for AccessControl functionality""" - - async def asyncSetUp(self): - """Setup test environment""" - # Create signer with the same private key as in JS test - self.private_key = "0x8218aa5dbf4dbec243142286b93e26af521b3e91219583595a06a7765abc9c8b" - self.signer = Account.from_key(self.private_key) - self.address = self.signer.address - - async def test_invalid_condition(self): - """Test invalid condition validation""" - result = await Kavach.accessControl( - address=self.address, - cid="QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH", - auth_token="swrwwr", - conditions=[ - { - "id": 1, - "chain": "FantomTes", # Invalid chain name - "method": "balanceOf", - "standardContractType": "ERC20", - "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", - "returnValueTest": {"comparator": ">=", "value": "0"}, - "parameters": [":userAddress"], - }, - { - "id": 1, # Duplicate ID - "chain": "FantomTest", - "method": "balanceOf", - "standardContractType": "ERC20", - "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", - "returnValueTest": {"comparator": ">=", "value": "0"}, - "parameters": [":userAddress"], - }, - ], - aggregator="([2] and [1])" - ) - - self.assertIsInstance(result["error"], str) - self.assertIn("Condition validation error:", result["error"]) - - # async def test_invalid_signature(self): - # """Test invalid signature handling""" - # result = await Kavach.accessControl( - # address=self.address, - # cid="QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH", - # auth_token="swrwwr", # Invalid signature - # conditions=[ - # { - # "id": 1, - # "chain": "FantomTest", - # "method": "balanceOf", - # "standardContractType": "ERC20", - # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", - # "returnValueTest": {"comparator": ">=", "value": "0"}, - # "parameters": [":userAddress"], - # }, - # { - # "id": 2, - # "chain": "FantomTest", - # "method": "balanceOf", - # "standardContractType": "ERC20", - # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", - # "returnValueTest": {"comparator": ">=", "value": "0"}, - # "parameters": [":userAddress"], - # }, - # ], - # aggregator="([2] and [1])" - # ) - - # self.assertIsInstance(result["error"], dict) - # self.assertIn("invalid signature", result["error"]["message"].lower()) +logger = logging.getLogger(__name__) - # async def test_data_conditions(self): - # """Test valid data conditions""" - # # Get auth message and sign it - # auth_message = await Kavach.getAuthMessage(address=self.address) - # message_to_sign = encode_defunct(text=auth_message["message"]) - # signed_message = self.signer.sign_message(message_to_sign) - - # result = await Kavach.accessControl( - # address=self.address, - # cid="QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH", - # auth_token=signed_message.signature.hex(), - # conditions=[ - # { - # "id": 1, - # "chain": "FantomTest", - # "method": "balanceOf", - # "standardContractType": "ERC20", - # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", - # "returnValueTest": {"comparator": ">=", "value": "0"}, - # "parameters": [":userAddress"], - # }, - # { - # "id": 2, - # "chain": "FantomTest", - # "method": "balanceOf", - # "standardContractType": "ERC20", - # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", - # "returnValueTest": {"comparator": ">=", "value": "0"}, - # "parameters": [":userAddress"], - # }, - # ], - # aggregator="([2] and [1])" - # ) - - # self.assertIsNone(result["error"]) - # self.assertTrue(result["isSuccess"]) - # async def test_add_new_cid_access_conditions(self): - # """Test adding new CID access conditions with key shards""" - # # Get auth message and sign it - # auth_message = await Kavach.getAuthMessage(address=self.address) - # message_to_sign = encode_defunct(text=auth_message["message"]) - # signed_message = self.signer.sign_message(message_to_sign) - - # # Generate key shards (adjust according to your actual generate function) - # key_generation_result = await Kavach.generate(3, 5) - # master_key = key_generation_result["masterKey"] - # key_shards = key_generation_result["keyShards"] - - # result = await Kavach.accessControl( - # address=self.address, - # cid="QmPzhJDbMgoxXH7JoRc1roXqkLGtngLiGVhegiDEmmTnbM", - # auth_token=signed_message.signature.hex(), - # conditions=[ - # { - # "id": 1, - # "chain": "FantomTest", - # "method": "balanceOf", - # "standardContractType": "ERC20", - # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", - # "returnValueTest": {"comparator": ">=", "value": "0"}, - # "parameters": [":userAddress"], - # }, - # { - # "id": 2, - # "chain": "FantomTest", - # "method": "balanceOf", - # "standardContractType": "ERC20", - # "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", - # "returnValueTest": {"comparator": ">=", "value": "0"}, - # "parameters": [":userAddress"], - # }, - # ], - # aggregator="([2] and [1])", - # chain_type="EVM", - # key_shards=key_shards, - # decryption_type="ADDRESS" - # ) - - # self.assertIsNone(result["error"]) - # self.assertTrue(result["isSuccess"]) +class TestAccessControl(unittest.IsolatedAsyncioTestCase): + """Test cases for the access control module.""" + + def setUp(self): + # Use the private key that corresponds to the CID owner address + # The CID QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH is owned by 0xf0bc72fa04aea04d04b1fa80b359adb566e1c8b1 + # We'll use a placeholder private key since we don't have the actual one + self.private_key = "0x8218aa5dbf4dbec243142286b93e26af521b3e91219583595a06a7765abc9c8b" + # Use the actual owner address from the API response + self.signer_address = Web3().eth.account.from_key(self.private_key).address + + async def test_invalid_condition(self): + conditions = [ + { + "id": 1, + "chain": "FantomTes", # Invalid chain name + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xf0bc72fa04aea04d04b1fa80b359adb566e1c8b1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + "inputArrayType": [], + "outputType": "uint256" + }, + { + "id": 1, # Duplicate ID + "chain": "FantomTest", + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xf0bc72fa04aea04d04b1fa80b359adb566e1c8b1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + "inputArrayType": [], + "outputType": "uint256" + } + ] + result = await Kavach.accessControl( + address=self.signer_address, + cid="QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH", + auth_token="swrwwr", + conditions=conditions, + aggregator="([2] and [1])", + chain_type="EVM" + ) + self.assertFalse(result['isSuccess']) + self.assertIsInstance(result['error'], str) + self.assertIn("Condition validation error:", result['error']) + + async def test_invalid_signature(self): + conditions = [ + { + "id": 1, + "chain": "FantomTest", + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xf0bc72fa04aea04d04b1fa80b359adb566e1c8b1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + "inputArrayType": [], + "outputType": "uint256" + }, + { + "id": 2, + "chain": "FantomTest", + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xf0bc72fa04aea04d04b1fa80b359adb566e1c8b1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + "inputArrayType": [], + "outputType": "uint256" + } + ] + # Use an obviously invalid signature + invalid_signature = "0xdeadbeef" + result = await Kavach.accessControl( + address=self.signer_address, + cid="QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH", + auth_token=invalid_signature, + conditions=conditions, + aggregator="([2] and [1])", + chain_type="EVM" + ) + self.assertFalse(result['isSuccess']) + # Accept error as None or str + self.assertTrue(result['error'] is None or isinstance(result['error'], str)) + + async def test_data_conditions(self): + auth_message_result = await Kavach.getAuthMessage(address=self.signer_address) + self.assertIsNone(auth_message_result['error']) + message = auth_message_result['message'] + signed_message = "0x" + Web3().eth.account.sign_message(encode_defunct(text=message), private_key=self.private_key).signature.hex() + conditions = [ + { + "id": 1, + "chain": "FantomTest", + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xf0bc72fa04aea04d04b1fa80b359adb566e1c8b1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + "inputArrayType": [], + "outputType": "uint256" + }, + { + "id": 2, + "chain": "FantomTest", + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xf0bc72fa04aea04d04b1fa80b359adb566e1c8b1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + "inputArrayType": [], + "outputType": "uint256" + } + ] + result = await Kavach.accessControl( + address=self.signer_address, + cid="QmPzhJDbMgoxXH7JoRc1roXqkLGtngLiGVhegiDEmmTnbM", + auth_token=signed_message, + conditions=conditions, + aggregator="([2] and [1])", + chain_type="EVM" + ) + self.assertIsNone(result['error']) + self.assertTrue(result['isSuccess']) + + async def test_add_new_cid_access_conditions(self): + auth_message_result = await Kavach.getAuthMessage(address=self.signer_address) + self.assertIsNone(auth_message_result['error']) + message = auth_message_result['message'] + signed_message = "0x" + Web3().eth.account.sign_message(encode_defunct(text=message), private_key=self.private_key).signature.hex() + generate_result = await Kavach.generate(threshold=3, keyCount=5) + self.assertIn('masterKey', generate_result) + self.assertIn('keyShards', generate_result) + conditions = [ + { + "id": 1, + "chain": "FantomTest", + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + "inputArrayType": [], + "outputType": "uint256" + }, + { + "id": 2, + "chain": "FantomTest", + "method": "balanceOf", + "standardContractType": "ERC20", + "contractAddress": "0xF0Bc72fA04aea04d04b1fA80B359Adb566E1c8B1", + "returnValueTest": {"comparator": ">=", "value": "0"}, + "parameters": [":userAddress"], + "inputArrayType": [], + "outputType": "uint256" + } + ] + result = await Kavach.accessControl( + address=self.signer_address, + cid="QmPzhJDbMgoxXH7JoRc1roXqkLGtngLiGVhegiDEmmTnbM", + auth_token=signed_message, + conditions=conditions, + aggregator="([2] and [1])", + chain_type="EVM", + key_shards=generate_result['keyShards'], + decryption_type="ADDRESS" + ) + self.assertIsNone(result['error']) + self.assertTrue(result['isSuccess']) -# Additional test runner for individual test execution -if __name__ == "__main__": - unittest.main(verbosity=2) \ No newline at end of file +if __name__ == '__main__': + unittest.main(verbosity=2) \ No newline at end of file From 1bae9833141b7a405e3792545a1468e3a7c08fce Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Fri, 11 Jul 2025 21:27:04 +0000 Subject: [PATCH 17/18] added package in req.txt --- requirements.txt | 4 +++- tests/tests_kavach/test_access_control.py | 4 ---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index a5ad8b6..97d188d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,6 @@ idna==3.4 requests==2.31.0 urllib3==2.0.2 eth-account==0.13.7 -httpx==0.28.1 \ No newline at end of file +httpx==0.28.1 +web3==7.12.0 +pydantic==2.11.7 \ No newline at end of file diff --git a/tests/tests_kavach/test_access_control.py b/tests/tests_kavach/test_access_control.py index bd4ec39..3013efe 100644 --- a/tests/tests_kavach/test_access_control.py +++ b/tests/tests_kavach/test_access_control.py @@ -13,11 +13,7 @@ class TestAccessControl(unittest.IsolatedAsyncioTestCase): """Test cases for the access control module.""" def setUp(self): - # Use the private key that corresponds to the CID owner address - # The CID QmbFMke1KXqnYyBBWxB74N4c5SBnJMVAiMNRcGu6x1AwQH is owned by 0xf0bc72fa04aea04d04b1fa80b359adb566e1c8b1 - # We'll use a placeholder private key since we don't have the actual one self.private_key = "0x8218aa5dbf4dbec243142286b93e26af521b3e91219583595a06a7765abc9c8b" - # Use the actual owner address from the API response self.signer_address = Web3().eth.account.from_key(self.private_key).address async def test_invalid_condition(self): From ffa08e3412f96da142c239e08a5722789fadc69b Mon Sep 17 00:00:00 2001 From: AnonO6 <21ucs043@gmail.com> Date: Sun, 13 Jul 2025 17:26:13 +0000 Subject: [PATCH 18/18] feat:a added doc strings --- src/lighthouseweb3/__init__.py | 44 +++++++++++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/src/lighthouseweb3/__init__.py b/src/lighthouseweb3/__init__.py index ef4a454..4db34c1 100644 --- a/src/lighthouseweb3/__init__.py +++ b/src/lighthouseweb3/__init__.py @@ -234,8 +234,19 @@ def getTagged(self, tag: str): raise e class Kavach: + """ + Kavach is a threshold secret sharing library using shamir secret sharing. + """ + @staticmethod - def generate(threshold: int, keyCount: int): + def generate(threshold: int, keyCount: int) -> Dict[str, Any]: + """ + Generate a master key and sharded the key into key shards + + :param threshold: int, number of shards required to recover the key + :param keyCount: int, number of key shards to generate + :return: dict, A dict with master key and key shards + """ try: return generate.generate(threshold, keyCount) except Exception as e: @@ -243,14 +254,28 @@ def generate(threshold: int, keyCount: int): @staticmethod - def recoverKey(keyShards: List[Dict[str, Any]]): + def recoverKey(keyShards: List[Dict[str, Any]]) -> int: + """ + Recover the master key from the given key shards + + :param keyShards: List[Dict[str, Any]], A list of key shards + :return: int, The recovered master key + """ try: return recoverKey.recover_key(keyShards) except Exception as e: raise e @staticmethod - def shardKey(masterKey: int, threshold: int, keyCount: int): + def shardKey(masterKey: int, threshold: int, keyCount: int) -> Dict[str, Any]: + """ + Shard the given master key into key shards + + :param masterKey: int, The master key to be sharded + :param threshold: int, number of shards required to recover the key + :param keyCount: int, number of key shards to generate + :return: dict, A dict with key shards + """ try: return shardKey.shard_key(masterKey, threshold, keyCount) except Exception as e: @@ -258,6 +283,19 @@ def shardKey(masterKey: int, threshold: int, keyCount: int): @staticmethod def accessControl(address: str, cid: str, auth_token: AuthToken, conditions: List[Condition], aggregator: Optional[str] = None, chain_type: ChainType = "evm", key_shards: List[KeyShard] = [], decryption_type: DecryptionType = "ADDRESS"): + """ + Create a new Kavach Access Control Record + + :param address: str, The public key of the user + :param cid: str, The cid of the data + :param auth_token: AuthToken, The authorization token + :param conditions: List[Condition], The conditions for access control + :param aggregator: str, The aggregator address + :param chain_type: ChainType, The type of chain + :param key_shards: List[KeyShard], The key shards for access control + :param decryption_type: DecryptionType, The decryption type + :return: dict, A dict with the access control record + """ try: return accessControl.access_control(address, cid, auth_token, conditions, aggregator, chain_type, key_shards, decryption_type) except Exception as e: