Skip to content

Commit dd01041

Browse files
committed
Fix OrderedSet serialization
1 parent fb6cdeb commit dd01041

File tree

4 files changed

+51
-19
lines changed

4 files changed

+51
-19
lines changed

pycardano/serialization.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class RawCBOR:
160160
Fraction,
161161
FrozenList,
162162
IndefiniteFrozenList,
163+
ByteString,
163164
)
164165
"""
165166
A list of types that could be encoded by
@@ -1128,33 +1129,53 @@ def list_hook(
11281129
return lambda vals: [cls.from_primitive(v) for v in vals]
11291130

11301131

1131-
class OrderedSet(list, IndefiniteList, Generic[T], CBORSerializable): # type: ignore
1132+
class OrderedSet(Generic[T], CBORSerializable): # type: ignore
11321133
def __init__(
11331134
self,
11341135
iterable: Optional[Union[List[T], IndefiniteList]] = None,
11351136
use_tag: bool = True,
11361137
):
11371138
super().__init__()
1138-
self._set: Set[str] = set()
1139+
self._dict: Dict[bytes, int] = {}
1140+
self._list: List[T] = []
11391141
self._use_tag = use_tag
11401142
self._is_indefinite_list = False
11411143
if iterable:
11421144
self._is_indefinite_list = isinstance(iterable, IndefiniteList)
11431145
self.extend(iterable)
11441146

11451147
def append(self, item: T) -> None:
1146-
item_key = str(item)
1147-
if item_key not in self._set:
1148-
super().append(item)
1149-
self._set.add(item_key)
1148+
if item in self:
1149+
return
1150+
self._list.append(item)
1151+
self._dict[dumps(item, default=default_encoder)] = len(self._list) - 1
11501152

11511153
def extend(self, items: Iterable[T]) -> None:
11521154
self._is_indefinite_list = isinstance(items, IndefiniteList)
11531155
for item in items:
11541156
self.append(item)
11551157

1158+
def remove(self, item: T) -> None:
1159+
if item not in self:
1160+
return
1161+
index = self._dict.pop(dumps(item, default=default_encoder))
1162+
self._list.pop(index)
1163+
# Update the indices in the dictionary
1164+
for key, idx in self._dict.items():
1165+
if idx > index:
1166+
self._dict[key] = idx - 1
1167+
11561168
def __contains__(self, item: object) -> bool:
1157-
return str(item) in self._set
1169+
return dumps(item, default=default_encoder) in self._dict
1170+
1171+
def __iter__(self):
1172+
return iter(self._list)
1173+
1174+
def __getitem__(self, index: int) -> T:
1175+
return self._list[index]
1176+
1177+
def __len__(self) -> int:
1178+
return len(self._list)
11581179

11591180
def __eq__(self, other: object) -> bool:
11601181
if not isinstance(other, OrderedSet):

pycardano/utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,7 @@ def script_data_hash(
260260
redeemer_bytes = cbor2.dumps(redeemers, default=default_encoder)
261261

262262
if datums:
263-
if not isinstance(datums, NonEmptyOrderedSet):
264-
# If datums is not a NonEmptyOrderedSet, handle it as a list
265-
datum_bytes = cbor2.dumps(datums, default=default_encoder)
266-
else:
267-
datum_bytes = cbor2.dumps(
268-
datums.to_shallow_primitive(), default=default_encoder
269-
)
263+
datum_bytes = cbor2.dumps(datums, default=default_encoder)
270264
else:
271265
datum_bytes = b""
272266

test/pycardano/test_serialization.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,23 @@ def test_ordered_set():
638638
assert list(s) == [1, 2, 3]
639639
assert s._use_tag
640640

641+
# Test remove
642+
s = OrderedSet([1, 2, 3, 4])
643+
s.remove(2)
644+
assert list(s) == [1, 3, 4]
645+
assert 2 not in s
646+
assert 1 in s
647+
assert 3 in s
648+
assert 4 in s
649+
s.remove(2)
650+
assert list(s) == [1, 3, 4]
651+
assert 2 not in s
652+
s.remove(3)
653+
assert list(s) == [1, 4]
654+
assert 3 not in s
655+
s.remove(4)
656+
assert list(s) == [1]
657+
641658

642659
def test_ordered_set_with_complex_types():
643660
# Test with VerificationKeyWitness
@@ -909,9 +926,9 @@ class MyOrderedSet(OrderedSet):
909926
assert 4 not in s
910927

911928
# Test with complex objects
912-
class TestObj:
913-
def __init__(self, value):
914-
self.value = value
929+
@dataclass(repr=False)
930+
class TestObj(ArrayCBORSerializable):
931+
value: str
915932

916933
def __str__(self):
917934
return f"TestObj({self.value})"

test/pycardano/test_txbuilder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,12 @@ def test_tx_builder_mint_multi_asset(chain_context):
411411
[
412412
sender_address.to_primitive(),
413413
[
414-
5809111,
414+
5809155,
415415
{b"1111111111111111111111111111": {b"Token1": 1, b"Token2": 2}},
416416
],
417417
],
418418
],
419-
2: 190889,
419+
2: 190845,
420420
3: 123456789,
421421
8: 1000,
422422
9: mint,

0 commit comments

Comments
 (0)