diff --git a/contracts/contracts/coordination/Coordinator.sol b/contracts/contracts/coordination/Coordinator.sol index ee3e6598..097d9e48 100644 --- a/contracts/contracts/coordination/Coordinator.sol +++ b/contracts/contracts/coordination/Coordinator.sol @@ -479,6 +479,32 @@ contract Coordinator is Initializable, AccessControlDefaultAdminRulesUpgradeable return (participant, index); } + function getParticipants( + uint32 ritualId, + uint256 _startIndex, + uint256 _maxParticipants, + bool _includeTranscript + ) external view returns (Participant[] memory) { + Ritual storage ritual = rituals[ritualId]; + uint256 endIndex = ritual.participant.length; + require(_startIndex < endIndex, "Wrong start index"); + if (_maxParticipants != 0 && _startIndex + _maxParticipants < endIndex) { + endIndex = _startIndex + _maxParticipants; + } + Participant[] memory ritualParticipants = new Participant[](endIndex - _startIndex); + + uint256 resultIndex = 0; + for (uint256 i = _startIndex; i < endIndex; i++) { + Participant memory ritualParticipant = ritual.participant[i]; + if (!_includeTranscript) { + ritualParticipant.transcript = ""; + } + ritualParticipants[resultIndex++] = ritualParticipant; + } + + return ritualParticipants; + } + function getProviders(uint32 ritualId) external view returns (address[] memory) { Ritual storage ritual = rituals[ritualId]; address[] memory providers = new address[](ritual.participant.length); diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 40e8d1c1..3823ba86 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -318,6 +318,52 @@ def test_post_transcript_but_not_waiting_for_transcripts( coordinator.postTranscript(0, transcript, sender=nodes[1]) +def test_get_participants(coordinator, nodes, initiator, erc20, global_allow_list): + initiate_ritual( + coordinator=coordinator, + erc20=erc20, + authority=initiator, + nodes=nodes, + allow_logic=global_allow_list, + ) + transcript = os.urandom(transcript_size(len(nodes), len(nodes))) + + for node in nodes: + tx = coordinator.postTranscript(0, transcript, sender=node) + + # get all participants + participants = coordinator.getParticipants(0, 0, len(nodes), False) + assert len(participants) == len(nodes) + for index, participant in enumerate(participants): + assert participant.provider == nodes[index].address + assert participant.aggregated is False + assert not participant.transcript + + # max is higher than available + participants = coordinator.getParticipants(0, 0, len(nodes)*2, False) + assert len(participants) == len(nodes) + for index, participant in enumerate(participants): + assert participant.provider == nodes[index].address + assert participant.aggregated is False + assert not participant.transcript + + # n at a time + for n_at_a_time in [2]: + index = 0 + while index < len(nodes): + print(f">>>> Start Index {index}, End Index {index+n_at_a_time}") + participants_n_at_a_time = coordinator.getParticipants(0, index, index+n_at_a_time, True) + assert len(participants_n_at_a_time) <= n_at_a_time + for i, participant in enumerate(participants_n_at_a_time): + assert participant.provider == nodes[index+i].address + assert participant.aggregated is False + assert participant.transcript == transcript + + index += n_at_a_time + + assert index == len(nodes) + + def test_post_aggregation( coordinator, nodes, initiator, erc20, global_allow_list, treasury, deployer ):