feat: Initial support for cooperative-sticky rebalancing #407

Merged: 6 commits, Dec 17, 2024
22 changes: 16 additions & 6 deletions arroyo/backends/kafka/consumer.py
@@ -161,6 +161,10 @@ def __init__(
)

configuration = dict(configuration)
self.__is_incremental = (
configuration.get("partition.assignment.strategy") == "cooperative-sticky"
or configuration.get("group.protocol") == "consumer"
)
auto_offset_reset = configuration.get("auto.offset.reset", "largest")

# This is a special flag that controls the auto offset behavior for
@@ -269,6 +273,10 @@ def subscribe(
def assignment_callback(
consumer: ConfluentConsumer, partitions: Sequence[ConfluentTopicPartition]
) -> None:
if not partitions:
logger.info("skipping empty assignment")
return
Comment on lines +277 to +278
Contributor:
Why do you need different logic for partition assignment between cooperative and standard rebalancing in the case of an empty assignment?
I assume you can get an empty assignment in cooperative rebalancing when, after a rebalance, your assignments do not change. Is that the scenario where you do not want to touch the existing assignments?

Member Author:

After sleeping on it I agree. I only added this because it made cooperative rebalancing more comprehensible, and I wasn't sure of the implications for regular rebalancing. I think we can skip empty assignments regardless of the assignment strategy.


self.__state = KafkaConsumerState.ASSIGNING

try:
@@ -451,12 +459,14 @@ def __validate_offsets(self, offsets: Mapping[Partition, int]) -> None:

def __assign(self, offsets: Mapping[Partition, int]) -> None:
self.__validate_offsets(offsets)
self.__consumer.assign(
[
ConfluentTopicPartition(partition.topic.name, partition.index, offset)
for partition, offset in offsets.items()
]
)
partitions = [
ConfluentTopicPartition(partition.topic.name, partition.index, offset)
for partition, offset in offsets.items()
]
if self.__is_incremental:
self.__consumer.incremental_assign(partitions)
else:
self.__consumer.assign(partitions)
self.__offsets.update(offsets)

def seek(self, offsets: Mapping[Partition, int]) -> None:
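
For context, a minimal sketch of the consumer configuration that triggers the incremental path added above; the broker address and group name are placeholders, and the rest of the usual consumer settings are omitted:

# Minimal sketch (placeholder broker and group). Either of the two
# rebalance settings below makes __is_incremental true, so assignments
# are applied with confluent-kafka's incremental_assign() instead of
# assign().
cooperative_configuration = {
    "bootstrap.servers": "localhost:9092",  # placeholder broker
    "group.id": "example-group",            # placeholder group name
    "partition.assignment.strategy": "cooperative-sticky",
    # "group.protocol": "consumer",  # KIP-848 alternative, marked not functional yet in this PR
}
# This mapping would be merged into the usual KafkaConsumer configuration.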
10 changes: 5 additions & 5 deletions arroyo/dlq.py
@@ -315,11 +315,11 @@ def pop(

return None

def reset(self) -> None:
def remove(self, partition: Partition) -> None:
"""
Reset the buffer.
Remove a revoked partition from the buffer.
"""
self.__buffered_messages = defaultdict(deque)
self.__buffered_messages.pop(partition, None)


class DlqPolicyWrapper(Generic[TStrategyPayload]):
@@ -343,9 +343,9 @@ def __init__(
]
],
] = defaultdict(deque)
self.reset_offsets({})
self.reset_dlq_limits({})

def reset_offsets(self, assignment: Mapping[Partition, int]) -> None:
def reset_dlq_limits(self, assignment: Mapping[Partition, int]) -> None:
"""
Called on consumer assignment
"""
20 changes: 15 additions & 5 deletions arroyo/processing/processor.py
@@ -238,16 +238,23 @@ def on_partitions_assigned(partitions: Mapping[Partition, int]) -> None:
"arroyo.consumer.partitions_assigned.count", len(partitions)
)

self.__buffered_messages.reset()
current_partitions = dict(self.__consumer.tell())
current_partitions.update(partitions)

if self.__dlq_policy:
self.__dlq_policy.reset_offsets(partitions)
if partitions:
self.__dlq_policy.reset_dlq_limits(current_partitions)
if current_partitions:
if self.__processing_strategy is not None:
logger.exception(
# TODO: for cooperative-sticky rebalancing this can happen
# quite often. we should port the changes to
# ProcessingStrategyFactory that we made in Rust: Remove
# create_with_partitions, replace with create +
# update_partitions
logger.warning(
"Partition assignment while processing strategy active"
)
_close_strategy()
_create_strategy(partitions)
_create_strategy(current_partitions)

@_rdkafka_callback(metrics=self.__metrics_buffer)
def on_partitions_revoked(partitions: Sequence[Partition]) -> None:
@@ -278,6 +285,9 @@ def on_partitions_revoked(partitions: Sequence[Partition]) -> None:
except RuntimeError:
pass

for partition in partitions:
self.__buffered_messages.remove(partition)

# Partition revocation can happen anytime during the consumer lifecycle and happen
# multiple times. What we want to know is that the consumer is not stuck somewhere.
# The presence of this message as the last message of a consumer
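
The merge of consumer.tell() with the incoming assignment above is the core cooperative-sticky adaptation: the assign callback now only carries newly granted partitions, so the processing strategy has to be built from the union of what the consumer already owns and the increment. A minimal sketch of that merge, with placeholder topic and offset values:

# Minimal sketch: under cooperative-sticky, `newly_assigned` is only the
# delta, so the full working set mirrors dict(consumer.tell()) followed
# by .update(partitions) in the processor above.
from typing import Dict, Mapping

from arroyo.types import Partition, Topic


def full_assignment(
    currently_owned: Mapping[Partition, int],
    newly_assigned: Mapping[Partition, int],
) -> Dict[Partition, int]:
    merged: Dict[Partition, int] = dict(currently_owned)
    merged.update(newly_assigned)
    return merged


topic = Topic("example-topic")  # placeholder topic
# The consumer already owns partition 0 at offset 12 and is granted partition 1.
assert full_assignment({Partition(topic, 0): 12}, {Partition(topic, 1): 0}) == {
    Partition(topic, 0): 12,
    Partition(topic, 1): 0,
}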
99 changes: 84 additions & 15 deletions tests/backends/mixins.py
@@ -2,7 +2,7 @@
import uuid
from abc import ABC, abstractmethod
from contextlib import closing
from typing import ContextManager, Generic, Iterator, Mapping, Optional, Sequence
from typing import Any, ContextManager, Generic, Iterator, Mapping, Optional, Sequence
from unittest import mock

import pytest
@@ -14,6 +14,8 @@


class StreamsTestMixin(ABC, Generic[TStrategyPayload]):
cooperative_sticky = False

@abstractmethod
def get_topic(self, partitions: int = 1) -> ContextManager[Topic]:
raise NotImplementedError
@@ -397,6 +399,11 @@ def test_pause_resume(self) -> None:
def test_pause_resume_rebalancing(self) -> None:
payloads = self.get_payloads()

consumer_a_on_assign = mock.Mock()
consumer_a_on_revoke = mock.Mock()
consumer_b_on_assign = mock.Mock()
consumer_b_on_revoke = mock.Mock()

with self.get_topic(2) as topic, closing(
self.get_producer()
) as producer, closing(
@@ -408,10 +415,22 @@ def test_pause_resume_rebalancing(self) -> None:
producer.produce(Partition(topic, i), next(payloads)).result(
timeout=5.0
)
for i in range(2)
for i in [0, 1]
]

consumer_a.subscribe([topic])
def wait_until_rebalancing(
from_consumer: Consumer[Any], to_consumer: Consumer[Any]
) -> None:
for _ in range(10):
assert from_consumer.poll(0) is None
if to_consumer.poll(1.0) is not None:
return

raise RuntimeError("no rebalancing happened")

consumer_a.subscribe(
[topic], on_assign=consumer_a_on_assign, on_revoke=consumer_a_on_revoke
)

# It doesn't really matter which message is fetched first -- we
# just want to know the assignment occurred.
@@ -428,19 +447,69 @@ def test_pause_resume_rebalancing(self) -> None:
[Partition(topic, 0), Partition(topic, 1)]
)

consumer_b.subscribe([topic])
for i in range(10):
assert consumer_a.poll(0) is None # attempt to force session timeout
if consumer_b.poll(1.0) is not None:
break
else:
assert False, "rebalance did not occur"
consumer_b.subscribe(
[topic], on_assign=consumer_b_on_assign, on_revoke=consumer_b_on_revoke
)

wait_until_rebalancing(consumer_a, consumer_b)

# The first consumer should have had its offsets rolled back, as
# well as should have had its partition resumed during
# rebalancing.
assert consumer_a.paused() == []
assert consumer_a.poll(10.0) is not None
if self.cooperative_sticky:
# With incremental rebalancing, only one partition should have been
# reassigned to consumer_b, and consumer_a should remain paused.
assert consumer_a.paused() == [Partition(topic, 1)]
assert consumer_a.poll(10.0) is None
else:
# The first consumer should have had its offsets rolled back, as
# well as should have had its partition resumed during
# rebalancing.
assert consumer_a.paused() == []
assert consumer_a.poll(10.0) is not None

assert len(consumer_a.tell()) == 1
assert len(consumer_b.tell()) == 1

(consumer_a_partition,) = consumer_a.tell()
(consumer_b_partition,) = consumer_b.tell()

# Pause consumer_a again.
consumer_a.pause(list(consumer_a.tell()))
# if we close consumer_a, consumer_b should get all partitions
producer.produce(next(iter(consumer_a.tell())), next(payloads)).result(
timeout=5.0
)
consumer_a.unsubscribe()
wait_until_rebalancing(consumer_a, consumer_b)

assert len(consumer_b.tell()) == 2

if self.cooperative_sticky:

assert consumer_a_on_assign.mock_calls == [
mock.call({Partition(topic, 0): 0, Partition(topic, 1): 0}),
]
assert consumer_a_on_revoke.mock_calls == [
mock.call([Partition(topic, 0)]),
mock.call([Partition(topic, 1)]),
]

assert consumer_b_on_assign.mock_calls == [
mock.call({Partition(topic, 0): 0}),
mock.call({Partition(topic, 1): 0}),
]
assert consumer_b_on_revoke.mock_calls == []
else:
assert consumer_a_on_assign.mock_calls == [
mock.call({Partition(topic, 0): 0, Partition(topic, 1): 0}),
mock.call({consumer_a_partition: 0}),
]
assert consumer_a_on_revoke.mock_calls == [
mock.call([Partition(topic, 0), Partition(topic, 1)]),
mock.call([consumer_a_partition]),
]

assert consumer_b_on_assign.mock_calls == [
mock.call({consumer_b_partition: 0}),
mock.call({Partition(topic, 0): 0, Partition(topic, 1): 0}),
]
assert consumer_b_on_revoke.mock_calls == [
mock.call([consumer_b_partition])
]
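
The assertions above capture the observable difference between the two protocols: the eager default revokes the full assignment and re-assigns it on every rebalance, while cooperative-sticky only reports the partitions that actually moved. A small sketch of the callback shapes the mixin relies on, assuming arroyo's on_assign/on_revoke contract of a Partition-to-offset mapping and a Partition sequence:

# Sketch of the rebalance callbacks passed to Consumer.subscribe() in the
# mixin above; the bodies are illustrative logging only.
import logging
from typing import Mapping, Sequence

from arroyo.types import Partition

logger = logging.getLogger(__name__)


def on_assign(partitions: Mapping[Partition, int]) -> None:
    # Cooperative-sticky: only newly granted partitions arrive here, e.g.
    # {Partition(topic, 1): 0} rather than the whole assignment.
    logger.info("assigned: %r", partitions)


def on_revoke(partitions: Sequence[Partition]) -> None:
    # Cooperative-sticky: only partitions actually taken away arrive here.
    logger.info("revoked: %r", partitions)


# consumer.subscribe([topic], on_assign=on_assign, on_revoke=on_revoke)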
43 changes: 37 additions & 6 deletions tests/backends/test_kafka.py
@@ -14,7 +14,10 @@

from arroyo.backends.kafka import KafkaConsumer, KafkaPayload, KafkaProducer
from arroyo.backends.kafka.commit import CommitCodec
from arroyo.backends.kafka.configuration import build_kafka_configuration
from arroyo.backends.kafka.configuration import (
KafkaBrokerConfig,
build_kafka_configuration,
)
from arroyo.backends.kafka.consumer import as_kafka_configuration_bool
from arroyo.commit import IMMEDIATE, Commit
from arroyo.errors import ConsumerError, EndOfPartition
@@ -56,6 +59,7 @@ def get_topic(
configuration: Mapping[str, Any], partitions_count: int
) -> Iterator[Topic]:
name = f"test-{uuid.uuid1().hex}"
configuration = dict(configuration)
client = AdminClient(configuration)
[[key, future]] = client.create_topics(
[NewTopic(name, num_partitions=partitions_count, replication_factor=1)]
@@ -71,10 +75,15 @@ def get_topic(


class TestKafkaStreams(StreamsTestMixin[KafkaPayload]):
kip_848 = False

configuration = build_kafka_configuration(
{"bootstrap.servers": os.environ.get("DEFAULT_BROKERS", "localhost:9092")}
)
@property
def configuration(self) -> KafkaBrokerConfig:
config = {
"bootstrap.servers": os.environ.get("DEFAULT_BROKERS", "localhost:9092"),
}

return build_kafka_configuration(config)

@contextlib.contextmanager
def get_topic(self, partitions: int = 1) -> Iterator[Topic]:
@@ -90,7 +99,7 @@ def get_consumer(
enable_end_of_partition: bool = True,
auto_offset_reset: str = "earliest",
strict_offset_reset: Optional[bool] = None,
max_poll_interval_ms: Optional[int] = None
max_poll_interval_ms: Optional[int] = None,
) -> KafkaConsumer:
configuration = {
**self.configuration,
@@ -110,6 +119,16 @@ def get_consumer(
if max_poll_interval_ms < 45000:
configuration["session.timeout.ms"] = max_poll_interval_ms

if self.cooperative_sticky:
configuration["partition.assignment.strategy"] = "cooperative-sticky"

if self.kip_848:
configuration["group.protocol"] = "consumer"
configuration.pop("session.timeout.ms")
configuration.pop("max.poll.interval.ms", None)
assert "group.protocol.type" not in configuration
assert "heartbeat.interval.ms" not in configuration

return KafkaConsumer(configuration)

def get_producer(self) -> KafkaProducer:
@@ -210,7 +229,9 @@ def test_consumer_polls_when_paused(self) -> None:
poll_interval = 6000

with self.get_topic() as topic:
with closing(self.get_producer()) as producer, closing(self.get_consumer(max_poll_interval_ms=poll_interval)) as consumer:
with closing(self.get_producer()) as producer, closing(
self.get_consumer(max_poll_interval_ms=poll_interval)
) as consumer:
producer.produce(topic, next(self.get_payloads())).result(5.0)

processor = StreamProcessor(consumer, topic, factory, IMMEDIATE)
@@ -245,6 +266,16 @@ def test_consumer_polls_when_paused(self) -> None:
assert consumer.paused() == []


class TestKafkaStreamsIncrementalRebalancing(TestKafkaStreams):
Member:
unused?

Member Author:

No, this actually re-declares all the tests in TestKafkaStreams, re-running them with cooperative-sticky rebalancing.

# re-test the kafka consumer with cooperative-sticky rebalancing
cooperative_sticky = True


@pytest.mark.skip("kip-848 not functional yet")
class TestKafkaStreamsKip848(TestKafkaStreams):
kip_848 = True


def test_commit_codec() -> None:
commit = Commit(
"group", Partition(Topic("topic"), 0), 0, time.time(), time.time() - 5