feat: Initial support for cooperative-sticky rebalancing #407

Open · wants to merge 2 commits into base: main
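Based on the diff below, a hedged sketch of how a consumer would opt into the new behaviour. It assumes arroyo's `KafkaConsumer` keeps accepting a plain librdkafka configuration dict and that the opt-in key stays `partition.assignment.strategy` (a review comment below suggests `group.protocol` instead); the broker address, group id, and topic name are placeholders.

```python
from arroyo.backends.kafka import KafkaConsumer
from arroyo.backends.kafka.configuration import build_kafka_configuration
from arroyo.types import Topic

configuration = build_kafka_configuration(
    {
        "bootstrap.servers": "localhost:9092",  # placeholder broker
        "group.id": "example-group",            # placeholder consumer group
        # Opt into incremental (cooperative) rebalancing; without this key the
        # consumer keeps the existing eager assign/revoke behaviour.
        "partition.assignment.strategy": "cooperative-sticky",
    }
)

consumer = KafkaConsumer(configuration)
consumer.subscribe([Topic("example-topic")])
```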
19 changes: 13 additions & 6 deletions arroyo/backends/kafka/consumer.py
@@ -161,6 +161,7 @@ def __init__(
)

configuration = dict(configuration)
self.__assignment_strategy = configuration.get("partition.assignment.strategy")
Review comment (Member): sorry i said the wrong thing earlier, this should be `group.protocol`

auto_offset_reset = configuration.get("auto.offset.reset", "largest")

# This is a special flag that controls the auto offset behavior for
@@ -269,6 +270,10 @@ def subscribe(
def assignment_callback(
consumer: ConfluentConsumer, partitions: Sequence[ConfluentTopicPartition]
) -> None:
if not partitions and self.__assignment_strategy == "cooperative-sticky":
logger.info("skipping empty assignment")
return

self.__state = KafkaConsumerState.ASSIGNING

try:
@@ -451,12 +456,14 @@ def __validate_offsets(self, offsets: Mapping[Partition, int]) -> None:

def __assign(self, offsets: Mapping[Partition, int]) -> None:
self.__validate_offsets(offsets)
self.__consumer.assign(
[
ConfluentTopicPartition(partition.topic.name, partition.index, offset)
for partition, offset in offsets.items()
]
)
partitions = [
ConfluentTopicPartition(partition.topic.name, partition.index, offset)
for partition, offset in offsets.items()
]
if self.__assignment_strategy == "cooperative-sticky":
self.__consumer.incremental_assign(partitions)
else:
self.__consumer.assign(partitions)
self.__offsets.update(offsets)

def seek(self, offsets: Mapping[Partition, int]) -> None:
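For reviewers less familiar with the cooperative assignor, here is a minimal standalone sketch (not part of this PR) of the callback contract the hunks above rely on in confluent-kafka-python: with `partition.assignment.strategy` set to `cooperative-sticky`, the assign callback receives only the partitions being added (possibly an empty list), and the application uses `incremental_assign` / `incremental_unassign` instead of `assign` / `unassign`. Broker address, group id, and topic name are placeholders.

```python
from confluent_kafka import Consumer, TopicPartition

consumer = Consumer(
    {
        "bootstrap.servers": "localhost:9092",  # placeholder
        "group.id": "example-group",            # placeholder
        "partition.assignment.strategy": "cooperative-sticky",
    }
)

def on_assign(consumer: Consumer, partitions: list[TopicPartition]) -> None:
    # Under cooperative-sticky this is the delta: only newly added
    # partitions show up here, and the list can be empty.
    consumer.incremental_assign(partitions)

def on_revoke(consumer: Consumer, partitions: list[TopicPartition]) -> None:
    # Likewise, only the partitions being taken away are passed in.
    consumer.incremental_unassign(partitions)

consumer.subscribe(["example-topic"], on_assign=on_assign, on_revoke=on_revoke)
```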
10 changes: 5 additions & 5 deletions arroyo/dlq.py
@@ -315,11 +315,11 @@ def pop(

return None

def reset(self) -> None:
def remove(self, partition: Partition) -> None:
"""
Reset the buffer.
Remove a revoked partition from the buffer.
"""
self.__buffered_messages = defaultdict(deque)
self.__buffered_messages.pop(partition, None)


class DlqPolicyWrapper(Generic[TStrategyPayload]):
@@ -343,9 +343,9 @@ def __init__(
]
],
] = defaultdict(deque)
self.reset_offsets({})
self.reset_dlq_limits({})

def reset_offsets(self, assignment: Mapping[Partition, int]) -> None:
def reset_dlq_limits(self, assignment: Mapping[Partition, int]) -> None:
"""
Called on consumer assignment
"""
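To make the `reset()` to `remove(partition)` change concrete, a small self-contained sketch (plain `defaultdict`/`deque`, not the real buffered-messages class) of the behavioural difference: under incremental rebalancing only the revoked partition's buffer should be dropped, while buffers for partitions the consumer keeps must survive the rebalance.

```python
from collections import defaultdict, deque

# Hypothetical buffered messages keyed by partition index.
buffered: defaultdict[int, deque[bytes]] = defaultdict(deque)
buffered[0].extend([b"m1", b"m2"])
buffered[1].extend([b"m3"])

# Old reset() behaviour: every buffer was discarded on (re)assignment.
# New remove(partition) behaviour: drop only the revoked partition.
revoked_partition = 1
buffered.pop(revoked_partition, None)

assert list(buffered[0]) == [b"m1", b"m2"]  # retained partition is untouched
assert revoked_partition not in buffered
```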
20 changes: 15 additions & 5 deletions arroyo/processing/processor.py
@@ -238,16 +238,23 @@ def on_partitions_assigned(partitions: Mapping[Partition, int]) -> None:
"arroyo.consumer.partitions_assigned.count", len(partitions)
)

self.__buffered_messages.reset()
current_partitions = dict(self.__consumer.tell())
current_partitions.update(partitions)

if self.__dlq_policy:
self.__dlq_policy.reset_offsets(partitions)
if partitions:
self.__dlq_policy.reset_dlq_limits(current_partitions)
if current_partitions:
if self.__processing_strategy is not None:
logger.exception(
# TODO: for cooperative-sticky rebalancing this can happen
# quite often. we should port the changes to
# ProcessingStrategyFactory that we made in Rust: Remove
# create_with_partitions, replace with create +
# update_partitions
logger.error(
Review comment (Member): can we downgrade to warning if this is expected to happen on a regular basis?

"Partition assignment while processing strategy active"
)
_close_strategy()
_create_strategy(partitions)
_create_strategy(current_partitions)

@_rdkafka_callback(metrics=self.__metrics_buffer)
def on_partitions_revoked(partitions: Sequence[Partition]) -> None:
@@ -278,6 +285,9 @@ def on_partitions_revoked(partitions: Sequence[Partition]) -> None:
except RuntimeError:
pass

for partition in partitions:
self.__buffered_messages.remove(partition)

# Partition revocation can happen anytime during the consumer lifecycle and happen
# multiple times. What we want to know is that the consumer is not stuck somewhere.
# The presence of this message as the last message of a consumer
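For context on the `current_partitions` change above, a minimal sketch (hypothetical topic names and offsets, not arroyo API calls) of why the processor now merges the consumer's existing assignment with the callback's partitions: under cooperative-sticky the callback only carries the newly added partitions, but the processing strategy has to be created with the full assignment.

```python
from typing import Mapping

# What the consumer already owns, as reported by something like consumer.tell().
existing_assignment: Mapping[str, int] = {"events-0": 120}

# What an incremental (cooperative-sticky) assignment callback delivers:
# only the delta, not the full assignment.
newly_assigned: Mapping[str, int] = {"events-1": 0}

# Mirrors `current_partitions = dict(self.__consumer.tell())` followed by
# `current_partitions.update(partitions)` in the hunk above.
current_partitions = {**existing_assignment, **newly_assigned}
assert current_partitions == {"events-0": 120, "events-1": 0}
```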
99 changes: 84 additions & 15 deletions tests/backends/mixins.py
@@ -2,7 +2,7 @@
import uuid
from abc import ABC, abstractmethod
from contextlib import closing
from typing import ContextManager, Generic, Iterator, Mapping, Optional, Sequence
from typing import Any, ContextManager, Generic, Iterator, Mapping, Optional, Sequence
from unittest import mock

import pytest
@@ -14,6 +14,8 @@


class StreamsTestMixin(ABC, Generic[TStrategyPayload]):
incremental_rebalancing = False

@abstractmethod
def get_topic(self, partitions: int = 1) -> ContextManager[Topic]:
raise NotImplementedError
@@ -397,6 +399,11 @@ def test_pause_resume(self) -> None:
def test_pause_resume_rebalancing(self) -> None:
payloads = self.get_payloads()

consumer_a_on_assign = mock.Mock()
consumer_a_on_revoke = mock.Mock()
consumer_b_on_assign = mock.Mock()
consumer_b_on_revoke = mock.Mock()

with self.get_topic(2) as topic, closing(
self.get_producer()
) as producer, closing(
@@ -408,10 +415,22 @@ def test_pause_resume_rebalancing(self) -> None:
producer.produce(Partition(topic, i), next(payloads)).result(
timeout=5.0
)
for i in range(2)
for i in [0, 1]
]

consumer_a.subscribe([topic])
def wait_until_rebalancing(
from_consumer: Consumer[Any], to_consumer: Consumer[Any]
) -> None:
for _ in range(10):
assert from_consumer.poll(0) is None
if to_consumer.poll(1.0) is not None:
return

raise RuntimeError("no rebalancing happened")

consumer_a.subscribe(
[topic], on_assign=consumer_a_on_assign, on_revoke=consumer_a_on_revoke
)

# It doesn't really matter which message is fetched first -- we
# just want to know the assignment occurred.
@@ -428,19 +447,69 @@ def test_pause_resume_rebalancing(self) -> None:
[Partition(topic, 0), Partition(topic, 1)]
)

consumer_b.subscribe([topic])
for i in range(10):
assert consumer_a.poll(0) is None # attempt to force session timeout
if consumer_b.poll(1.0) is not None:
break
else:
assert False, "rebalance did not occur"
consumer_b.subscribe(
[topic], on_assign=consumer_b_on_assign, on_revoke=consumer_b_on_revoke
)

wait_until_rebalancing(consumer_a, consumer_b)

# The first consumer should have had its offsets rolled back, as
# well as should have had it's partition resumed during
# rebalancing.
assert consumer_a.paused() == []
assert consumer_a.poll(10.0) is not None
if self.incremental_rebalancing:
# within incremental rebalancing, only one partition should have been reassigned to the consumer_b, and consumer_a should remain paused
assert consumer_a.paused() == [Partition(topic, 1)]
assert consumer_a.poll(10.0) is None
else:
# The first consumer should have had its offsets rolled back, as
# well as should have had it's partition resumed during
# rebalancing.
assert consumer_a.paused() == []
assert consumer_a.poll(10.0) is not None

assert len(consumer_a.tell()) == 1
assert len(consumer_b.tell()) == 1

(consumer_a_partition,) = consumer_a.tell()
(consumer_b_partition,) = consumer_b.tell()

# Pause consumer_a again.
consumer_a.pause(list(consumer_a.tell()))
# if we close consumer_a, consumer_b should get all partitions
producer.produce(next(iter(consumer_a.tell())), next(payloads)).result(
timeout=5.0
)
consumer_a.unsubscribe()
wait_until_rebalancing(consumer_a, consumer_b)

assert len(consumer_b.tell()) == 2

if self.incremental_rebalancing:

assert consumer_a_on_assign.mock_calls == [
mock.call({Partition(topic, 0): 0, Partition(topic, 1): 0}),
]
assert consumer_a_on_revoke.mock_calls == [
mock.call([Partition(topic, 0)]),
mock.call([Partition(topic, 1)]),
]

assert consumer_b_on_assign.mock_calls == [
mock.call({Partition(topic, 0): 0}),
mock.call({Partition(topic, 1): 0}),
]
assert consumer_b_on_revoke.mock_calls == []
else:
assert consumer_a_on_assign.mock_calls == [
mock.call({Partition(topic, 0): 0, Partition(topic, 1): 0}),
mock.call({consumer_a_partition: 0}),
]
assert consumer_a_on_revoke.mock_calls == [
mock.call([Partition(topic, 0), Partition(topic, 1)]),
mock.call([consumer_a_partition]),
]

assert consumer_b_on_assign.mock_calls == [
mock.call({consumer_b_partition: 0}),
mock.call({Partition(topic, 0): 0, Partition(topic, 1): 0}),
]
assert consumer_b_on_revoke.mock_calls == [
mock.call([consumer_b_partition])
]
27 changes: 21 additions & 6 deletions tests/backends/test_kafka.py
@@ -14,7 +14,10 @@

from arroyo.backends.kafka import KafkaConsumer, KafkaPayload, KafkaProducer
from arroyo.backends.kafka.commit import CommitCodec
from arroyo.backends.kafka.configuration import build_kafka_configuration
from arroyo.backends.kafka.configuration import (
KafkaBrokerConfig,
build_kafka_configuration,
)
from arroyo.backends.kafka.consumer import as_kafka_configuration_bool
from arroyo.commit import IMMEDIATE, Commit
from arroyo.errors import ConsumerError, EndOfPartition
@@ -71,10 +74,16 @@ def get_topic(


class TestKafkaStreams(StreamsTestMixin[KafkaPayload]):
@property
def configuration(self) -> KafkaBrokerConfig:
config = {
"bootstrap.servers": os.environ.get("DEFAULT_BROKERS", "localhost:9092"),
}

configuration = build_kafka_configuration(
{"bootstrap.servers": os.environ.get("DEFAULT_BROKERS", "localhost:9092")}
)
if self.incremental_rebalancing:
config["partition.assignment.strategy"] = "cooperative-sticky"

return build_kafka_configuration(config)

@contextlib.contextmanager
def get_topic(self, partitions: int = 1) -> Iterator[Topic]:
@@ -90,7 +99,7 @@ def get_consumer(
enable_end_of_partition: bool = True,
auto_offset_reset: str = "earliest",
strict_offset_reset: Optional[bool] = None,
max_poll_interval_ms: Optional[int] = None
max_poll_interval_ms: Optional[int] = None,
) -> KafkaConsumer:
configuration = {
**self.configuration,
@@ -210,7 +219,9 @@ def test_consumer_polls_when_paused(self) -> None:
poll_interval = 6000

with self.get_topic() as topic:
with closing(self.get_producer()) as producer, closing(self.get_consumer(max_poll_interval_ms=poll_interval)) as consumer:
with closing(self.get_producer()) as producer, closing(
self.get_consumer(max_poll_interval_ms=poll_interval)
) as consumer:
producer.produce(topic, next(self.get_payloads())).result(5.0)

processor = StreamProcessor(consumer, topic, factory, IMMEDIATE)
@@ -245,6 +256,10 @@ def test_consumer_polls_when_paused(self) -> None:
assert consumer.paused() == []


class TestKafkaStreamsIncrementalRebalancing(TestKafkaStreams):
Review comment (Member): unused?

incremental_rebalancing = True


def test_commit_codec() -> None:
commit = Commit(
"group", Partition(Topic("topic"), 0), 0, time.time(), time.time() - 5
12 changes: 11 additions & 1 deletion tests/processing/test_processor.py
@@ -41,6 +41,7 @@ def test_stream_processor_lifecycle() -> None:

# The processor should accept heartbeat messages without an assignment or
# active processor.
consumer.tell.return_value = {}
consumer.poll.return_value = None
processor._run_once()

@@ -166,6 +167,7 @@ def test_stream_processor_termination_on_error() -> None:
offset = 0
now = datetime.now()

consumer.tell.return_value = {}
consumer.poll.return_value = BrokerValue(0, partition, offset, now)

exception = NotImplementedError("error")
@@ -199,6 +201,7 @@ def test_stream_processor_invalid_message_from_poll() -> None:
offset = 1
now = datetime.now()

consumer.tell.return_value = {}
consumer.poll.side_effect = [BrokerValue(0, partition, offset, now)]

strategy = mock.Mock()
@@ -236,6 +239,7 @@ def test_stream_processor_invalid_message_from_submit() -> None:
offset = 1
now = datetime.now()

consumer.tell.return_value = {}
consumer.poll.side_effect = [
BrokerValue(0, partition, offset, now),
BrokerValue(1, partition, offset + 1, now),
@@ -283,6 +287,7 @@ def test_stream_processor_create_with_partitions() -> None:
topic = Topic("topic")

consumer = mock.Mock()
consumer.tell.return_value = {}
strategy = mock.Mock()
factory = mock.Mock()
factory.create_with_partitions.return_value = strategy
@@ -306,13 +311,15 @@ def test_stream_processor_create_with_partitions() -> None:
assert factory.create_with_partitions.call_count == 1
assert create_args[1] == offsets_p0

consumer.tell.return_value = {**offsets_p0}

# Second partition assigned
offsets_p1 = {Partition(topic, 1): 0}
assignment_callback(offsets_p1)

create_args, _ = factory.create_with_partitions.call_args
assert factory.create_with_partitions.call_count == 2
assert create_args[1] == offsets_p1
assert create_args[1] == {**offsets_p1, **offsets_p0}
Review comment (Member): was this test change related to your other changes? since there's no cooperative rebalancing here, seems like the assertions should stay the same?


processor._run_once()

@@ -376,6 +383,7 @@ def run_commit_policy_test(
) -> Sequence[int]:
commit = mock.Mock()
consumer = mock.Mock()
consumer.tell.return_value = {}
consumer.commit_offsets = commit

factory = CommitOffsetsFactory()
@@ -551,6 +559,7 @@ def test_dlq() -> None:
partition = Partition(topic, 0)
consumer = mock.Mock()
consumer.poll.return_value = BrokerValue(0, partition, 1, datetime.now())
consumer.tell.return_value = {}
strategy = mock.Mock()
strategy.submit.side_effect = InvalidMessage(partition, 1)
factory = mock.Mock()
@@ -585,6 +594,7 @@ def test_healthcheck(tmpdir: py.path.local) -> None:
consumer = mock.Mock()
now = datetime.now()
consumer.poll.return_value = BrokerValue(0, partition, 1, now)
consumer.tell.return_value = {}
strategy = mock.Mock()
strategy.submit.side_effect = InvalidMessage(partition, 1)
factory = mock.Mock()