diff --git a/pycardano/serialization.py b/pycardano/serialization.py index 75c1c1fa..a82e5455 100644 --- a/pycardano/serialization.py +++ b/pycardano/serialization.py @@ -24,7 +24,6 @@ List, Optional, Sequence, - Set, Type, TypeVar, Union, @@ -160,6 +159,7 @@ class RawCBOR: Fraction, FrozenList, IndefiniteFrozenList, + ByteString, ) """ A list of types that could be encoded by @@ -1128,26 +1128,53 @@ def list_hook( return lambda vals: [cls.from_primitive(v) for v in vals] -class OrderedSet(list, Generic[T], CBORSerializable): - def __init__(self, iterable: Optional[List[T]] = None, use_tag: bool = True): +class OrderedSet(Generic[T], CBORSerializable): + def __init__( + self, + iterable: Optional[Union[List[T], IndefiniteList]] = None, + use_tag: bool = True, + ): super().__init__() - self._set: Set[str] = set() + self._dict: Dict[bytes, int] = {} + self._list: List[T] = [] self._use_tag = use_tag + self._is_indefinite_list = False if iterable: + self._is_indefinite_list = isinstance(iterable, IndefiniteList) self.extend(iterable) def append(self, item: T) -> None: - item_key = str(item) - if item_key not in self._set: - super().append(item) - self._set.add(item_key) + if item in self: + return + self._list.append(item) + self._dict[dumps(item, default=default_encoder)] = len(self._list) - 1 def extend(self, items: Iterable[T]) -> None: + self._is_indefinite_list = isinstance(items, IndefiniteList) for item in items: self.append(item) + def remove(self, item: T) -> None: + if item not in self: + return + index = self._dict.pop(dumps(item, default=default_encoder)) + self._list.pop(index) + # Update the indices in the dictionary + for key, idx in self._dict.items(): + if idx > index: + self._dict[key] = idx - 1 + def __contains__(self, item: object) -> bool: - return str(item) in self._set + return dumps(item, default=default_encoder) in self._dict + + def __iter__(self): + return iter(self._list) + + def __getitem__(self, index: int) -> T: + return self._list[index] + + def __len__(self) -> int: + return len(self._list) def __eq__(self, other: object) -> bool: if not isinstance(other, OrderedSet): @@ -1159,10 +1186,13 @@ def __eq__(self, other: object) -> bool: def __repr__(self) -> str: return f"{self.__class__.__name__}({list(self)})" - def to_shallow_primitive(self) -> Union[CBORTag, List[T]]: + def to_shallow_primitive(self) -> Union[CBORTag, Union[List[T], IndefiniteList]]: if self._use_tag: - return CBORTag(258, list(self)) - return list(self) + return CBORTag( + 258, + IndefiniteList(list(self)) if self._is_indefinite_list else list(self), + ) + return IndefiniteList(list(self)) if self._is_indefinite_list else list(self) @classmethod def from_primitive( @@ -1195,7 +1225,11 @@ def __deepcopy__(self, memo): class NonEmptyOrderedSet(OrderedSet[T]): - def __init__(self, iterable: Optional[List[T]] = None, use_tag: bool = True): + def __init__( + self, + iterable: Optional[Union[List[T], IndefiniteList]] = None, + use_tag: bool = True, + ): super().__init__(iterable, use_tag) def validate(self): diff --git a/pycardano/txbuilder.py b/pycardano/txbuilder.py index 1da8d8c6..d99fb20d 100644 --- a/pycardano/txbuilder.py +++ b/pycardano/txbuilder.py @@ -2,7 +2,7 @@ from copy import deepcopy from dataclasses import dataclass, field, fields -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from pycardano import RedeemerMap from pycardano.address import Address, AddressType @@ -616,7 +616,9 @@ def script_data_hash(self) -> Optional[ScriptDataHash]: ) ) return script_data_hash( - self.redeemers(), list(self.datums.values()), CostModels(cost_models) + self.redeemers(), + NonEmptyOrderedSet(list(self.datums.values())), + CostModels(cost_models), ) else: return None @@ -1170,6 +1172,7 @@ def build_witness_set( plutus_v1_scripts: NonEmptyOrderedSet[PlutusV1Script] = NonEmptyOrderedSet() plutus_v2_scripts: NonEmptyOrderedSet[PlutusV2Script] = NonEmptyOrderedSet() plutus_v3_scripts: NonEmptyOrderedSet[PlutusV3Script] = NonEmptyOrderedSet() + plutus_data: NonEmptyOrderedSet[Any] = NonEmptyOrderedSet() input_scripts = ( { @@ -1181,6 +1184,9 @@ def build_witness_set( else {} ) + for datum in self.datums.values(): + plutus_data.append(datum) + for script in self.scripts: if script_hash(script) not in input_scripts: if isinstance(script, NativeScript): @@ -1204,7 +1210,7 @@ def build_witness_set( plutus_v2_script=plutus_v2_scripts if plutus_v2_scripts else None, plutus_v3_script=plutus_v3_scripts if plutus_v3_scripts else None, redeemer=self.redeemers() if self._redeemer_list else None, - plutus_data=list(self.datums.values()) if self.datums else None, + plutus_data=plutus_data if plutus_data else None, ) def _ensure_no_input_exclusion_conflict(self): diff --git a/pycardano/utils.py b/pycardano/utils.py index 7c912160..00c5f6a4 100644 --- a/pycardano/utils.py +++ b/pycardano/utils.py @@ -12,8 +12,8 @@ from pycardano.backend.base import ChainContext from pycardano.hash import SCRIPT_DATA_HASH_SIZE, SCRIPT_HASH_SIZE, ScriptDataHash -from pycardano.plutus import COST_MODELS, CostModels, Datum, Redeemers -from pycardano.serialization import default_encoder +from pycardano.plutus import COST_MODELS, CostModels, Datum, RedeemerMap, Redeemers +from pycardano.serialization import NonEmptyOrderedSet, default_encoder from pycardano.transaction import MultiAsset, TransactionOutput, Value __all__ = [ @@ -235,30 +235,35 @@ def min_lovelace_post_alonzo(output: TransactionOutput, context: ChainContext) - def script_data_hash( - redeemers: Redeemers, - datums: List[Datum], + redeemers: Optional[Redeemers] = None, + datums: Optional[Union[List[Datum], NonEmptyOrderedSet[Datum]]] = None, cost_models: Optional[Union[CostModels, Dict]] = None, ) -> ScriptDataHash: """Calculate plutus script data hash Args: - redeemers (Redeemers): Redeemers to include. - datums (List[Datum]): Datums to include. + redeemers (Optional[Redeemers]): Redeemers to include. + datums (Optional[Union[List[Datum], NonEmptyOrderedSet[Datum]]]): Datums to include. cost_models (Optional[CostModels]): Cost models. Returns: ScriptDataHash: Plutus script data hash """ - if not redeemers: + if redeemers is None: + redeemers = RedeemerMap() + cost_models = {} + elif len(redeemers) == 0: cost_models = {} elif not cost_models: cost_models = COST_MODELS redeemer_bytes = cbor2.dumps(redeemers, default=default_encoder) + if datums: datum_bytes = cbor2.dumps(datums, default=default_encoder) else: datum_bytes = b"" + cost_models_bytes = cbor2.dumps(cost_models, default=default_encoder) return ScriptDataHash( diff --git a/pycardano/witness.py b/pycardano/witness.py index 3d86e1a5..54a94ce7 100644 --- a/pycardano/witness.py +++ b/pycardano/witness.py @@ -9,19 +9,13 @@ from pycardano.key import ExtendedVerificationKey, VerificationKey from pycardano.nativescript import NativeScript -from pycardano.plutus import ( - PlutusV1Script, - PlutusV2Script, - PlutusV3Script, - RawPlutusData, - Redeemers, -) +from pycardano.plutus import PlutusV1Script, PlutusV2Script, PlutusV3Script, Redeemers from pycardano.serialization import ( ArrayCBORSerializable, + IndefiniteList, MapCBORSerializable, NonEmptyOrderedSet, limit_primitive_type, - list_hook, ) __all__ = ["VerificationKeyWitness", "TransactionWitnessSet"] @@ -114,9 +108,11 @@ class TransactionWitnessSet(MapCBORSerializable): }, ) - plutus_data: Optional[List[Any]] = field( - default=None, - metadata={"optional": True, "key": 4, "object_hook": list_hook(RawPlutusData)}, + plutus_data: Optional[Union[IndefiniteList, List[Any], NonEmptyOrderedSet[Any]]] = ( + field( + default=None, + metadata={"optional": True, "key": 4}, + ) ) redeemer: Optional[Redeemers] = field( @@ -150,6 +146,10 @@ def __post_init__(self): self.vkey_witnesses = NonEmptyOrderedSet(self.vkey_witnesses) if isinstance(self.native_scripts, list): self.native_scripts = NonEmptyOrderedSet(self.native_scripts) + if isinstance(self.plutus_data, list) and not isinstance( + self.plutus_data, NonEmptyOrderedSet + ): + self.plutus_data = NonEmptyOrderedSet(list(self.plutus_data)) if isinstance(self.plutus_v1_script, list): self.plutus_v1_script = NonEmptyOrderedSet(self.plutus_v1_script) if isinstance(self.plutus_v2_script, list): diff --git a/test/pycardano/test_serialization.py b/test/pycardano/test_serialization.py index c75503c5..5cbf5df2 100644 --- a/test/pycardano/test_serialization.py +++ b/test/pycardano/test_serialization.py @@ -638,6 +638,23 @@ def test_ordered_set(): assert list(s) == [1, 2, 3] assert s._use_tag + # Test remove + s = OrderedSet([1, 2, 3, 4]) + s.remove(2) + assert list(s) == [1, 3, 4] + assert 2 not in s + assert 1 in s + assert 3 in s + assert 4 in s + s.remove(2) + assert list(s) == [1, 3, 4] + assert 2 not in s + s.remove(3) + assert list(s) == [1, 4] + assert 3 not in s + s.remove(4) + assert list(s) == [1] + def test_ordered_set_with_complex_types(): # Test with VerificationKeyWitness @@ -909,9 +926,9 @@ class MyOrderedSet(OrderedSet): assert 4 not in s # Test with complex objects - class TestObj: - def __init__(self, value): - self.value = value + @dataclass(repr=False) + class TestObj(ArrayCBORSerializable): + value: str def __str__(self): return f"TestObj({self.value})" diff --git a/test/pycardano/test_util.py b/test/pycardano/test_util.py index 17b4f647..8baa6741 100644 --- a/test/pycardano/test_util.py +++ b/test/pycardano/test_util.py @@ -2,8 +2,19 @@ import pytest +from pycardano import NonEmptyOrderedSet from pycardano.hash import SCRIPT_HASH_SIZE, ScriptDataHash -from pycardano.plutus import ExecutionUnits, PlutusData, Redeemer, RedeemerTag, Unit +from pycardano.plutus import ( + COST_MODELS, + ExecutionUnits, + PlutusData, + Redeemer, + RedeemerKey, + RedeemerMap, + RedeemerTag, + RedeemerValue, + Unit, +) from pycardano.transaction import Value from pycardano.utils import ( min_lovelace_pre_alonzo, @@ -160,6 +171,25 @@ def test_script_data_hash(): ) == script_data_hash(redeemers=redeemers, datums=[unit]) +def test_script_data_hash_redeemer_map(): + unit = Unit() + redeemer = Redeemer(42, ExecutionUnits(573240, 253056459)) + redeemer.tag = RedeemerTag.SPEND + redeemers = RedeemerMap( + { + RedeemerKey(redeemer.tag, redeemer.index): RedeemerValue( + redeemer.data, redeemer.ex_units + ) + } + ) + cost_models = COST_MODELS + assert ScriptDataHash.from_primitive( + "04ad5eb241d1ede2bbbd60c5853de7659d2ecfb1a29d6cbb6921ef7bdd46ca3c" + ) == script_data_hash( + redeemers=redeemers, datums=NonEmptyOrderedSet([unit]), cost_models=cost_models + ) + + def test_script_data_hash_datum_only(): unit = Unit() assert ScriptDataHash.from_primitive(