Skip to content

Commit d63d31a

Browse files
committed
Fix OrderedSet serialization
1 parent fb6cdeb commit d63d31a

File tree

4 files changed

+51
-20
lines changed

4 files changed

+51
-20
lines changed

pycardano/serialization.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
List,
2525
Optional,
2626
Sequence,
27-
Set,
2827
Type,
2928
TypeVar,
3029
Union,
@@ -160,6 +159,7 @@ class RawCBOR:
160159
Fraction,
161160
FrozenList,
162161
IndefiniteFrozenList,
162+
ByteString,
163163
)
164164
"""
165165
A list of types that could be encoded by
@@ -1128,33 +1128,53 @@ def list_hook(
11281128
return lambda vals: [cls.from_primitive(v) for v in vals]
11291129

11301130

1131-
class OrderedSet(list, IndefiniteList, Generic[T], CBORSerializable): # type: ignore
1131+
class OrderedSet(Generic[T], CBORSerializable): # type: ignore
11321132
def __init__(
11331133
self,
11341134
iterable: Optional[Union[List[T], IndefiniteList]] = None,
11351135
use_tag: bool = True,
11361136
):
11371137
super().__init__()
1138-
self._set: Set[str] = set()
1138+
self._dict: Dict[bytes, int] = {}
1139+
self._list: List[T] = []
11391140
self._use_tag = use_tag
11401141
self._is_indefinite_list = False
11411142
if iterable:
11421143
self._is_indefinite_list = isinstance(iterable, IndefiniteList)
11431144
self.extend(iterable)
11441145

11451146
def append(self, item: T) -> None:
1146-
item_key = str(item)
1147-
if item_key not in self._set:
1148-
super().append(item)
1149-
self._set.add(item_key)
1147+
if item in self:
1148+
return
1149+
self._list.append(item)
1150+
self._dict[dumps(item, default=default_encoder)] = len(self._list) - 1
11501151

11511152
def extend(self, items: Iterable[T]) -> None:
11521153
self._is_indefinite_list = isinstance(items, IndefiniteList)
11531154
for item in items:
11541155
self.append(item)
11551156

1157+
def remove(self, item: T) -> None:
1158+
if item not in self:
1159+
return
1160+
index = self._dict.pop(dumps(item, default=default_encoder))
1161+
self._list.pop(index)
1162+
# Update the indices in the dictionary
1163+
for key, idx in self._dict.items():
1164+
if idx > index:
1165+
self._dict[key] = idx - 1
1166+
11561167
def __contains__(self, item: object) -> bool:
1157-
return str(item) in self._set
1168+
return dumps(item, default=default_encoder) in self._dict
1169+
1170+
def __iter__(self):
1171+
return iter(self._list)
1172+
1173+
def __getitem__(self, index: int) -> T:
1174+
return self._list[index]
1175+
1176+
def __len__(self) -> int:
1177+
return len(self._list)
11581178

11591179
def __eq__(self, other: object) -> bool:
11601180
if not isinstance(other, OrderedSet):

pycardano/utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,7 @@ def script_data_hash(
260260
redeemer_bytes = cbor2.dumps(redeemers, default=default_encoder)
261261

262262
if datums:
263-
if not isinstance(datums, NonEmptyOrderedSet):
264-
# If datums is not a NonEmptyOrderedSet, handle it as a list
265-
datum_bytes = cbor2.dumps(datums, default=default_encoder)
266-
else:
267-
datum_bytes = cbor2.dumps(
268-
datums.to_shallow_primitive(), default=default_encoder
269-
)
263+
datum_bytes = cbor2.dumps(datums, default=default_encoder)
270264
else:
271265
datum_bytes = b""
272266

test/pycardano/test_serialization.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,23 @@ def test_ordered_set():
638638
assert list(s) == [1, 2, 3]
639639
assert s._use_tag
640640

641+
# Test remove
642+
s = OrderedSet([1, 2, 3, 4])
643+
s.remove(2)
644+
assert list(s) == [1, 3, 4]
645+
assert 2 not in s
646+
assert 1 in s
647+
assert 3 in s
648+
assert 4 in s
649+
s.remove(2)
650+
assert list(s) == [1, 3, 4]
651+
assert 2 not in s
652+
s.remove(3)
653+
assert list(s) == [1, 4]
654+
assert 3 not in s
655+
s.remove(4)
656+
assert list(s) == [1]
657+
641658

642659
def test_ordered_set_with_complex_types():
643660
# Test with VerificationKeyWitness
@@ -909,9 +926,9 @@ class MyOrderedSet(OrderedSet):
909926
assert 4 not in s
910927

911928
# Test with complex objects
912-
class TestObj:
913-
def __init__(self, value):
914-
self.value = value
929+
@dataclass(repr=False)
930+
class TestObj(ArrayCBORSerializable):
931+
value: str
915932

916933
def __str__(self):
917934
return f"TestObj({self.value})"

test/pycardano/test_txbuilder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,12 +411,12 @@ def test_tx_builder_mint_multi_asset(chain_context):
411411
[
412412
sender_address.to_primitive(),
413413
[
414-
5809111,
414+
5809155,
415415
{b"1111111111111111111111111111": {b"Token1": 1, b"Token2": 2}},
416416
],
417417
],
418418
],
419-
2: 190889,
419+
2: 190845,
420420
3: 123456789,
421421
8: 1000,
422422
9: mint,

0 commit comments

Comments
 (0)