Skip to content

Commit

Permalink
Customizing batch message conversion behavior for Kafka binder
Browse files Browse the repository at this point in the history
Continuation of the previous commit: 14c1046
  • Loading branch information
sobychacko committed Sep 17, 2024
1 parent 14c1046 commit fac9eb1
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 137 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
<dependency>
<groupId>org.springframework.kafka</groupId>
<artifactId>spring-kafka</artifactId>
<version>3.3.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@

package org.springframework.cloud.stream.binder.kafka.config;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

import org.springframework.cloud.function.context.config.MessageConverterHelper;
import org.springframework.kafka.support.KafkaHeaders;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;

/**
* @author Oleg Zhurakousky
* @author Soby Chacko
*/
public class DefaultMessageConverterHelper implements MessageConverterHelper {

Expand All @@ -32,10 +37,22 @@ public boolean shouldFailIfCantConvert(Message<?> message) {
}

public void postProcessBatchMessageOnFailure(Message<?> message, int index) {
AtomicInteger deliveryAttempt = (AtomicInteger) message.getHeaders().get("deliveryAttempt");
// if (message.getHeaders().containsKey("amqp_batchedHeaders") && deliveryAttempt != null && deliveryAttempt.get() == 1) {
// ArrayList<?> list = (ArrayList<?>) message.getHeaders().get("amqp_batchedHeaders");
// list.remove(index);
// }
MessageHeaders headers = message.getHeaders();
Set<String> headerKeySet = headers.keySet();
List<String> matchingHeaderKeys = new ArrayList<>();

for (String string : headerKeySet) {
if (string.startsWith(KafkaHeaders.PREFIX)) {
matchingHeaderKeys.add(string);
}
}
for (String matchingHeaderKey : matchingHeaderKeys) {
Object matchingHeaderValue = message.getHeaders().get(matchingHeaderKey);
if (matchingHeaderValue instanceof ArrayList<?> list) {
if (!list.isEmpty()) {
list.remove(index);
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2019-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.cloud.stream.binder.kafka.config;

import org.springframework.cloud.function.context.config.MessageConverterHelper;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
* @author Oleg Zhurakousky
* @author Soby Chacko
*/
@Configuration(proxyBeanMethods = false)
public class MessageConverterHelperConfiguration {

@Bean
public MessageConverterHelper messageConverterHelper() {
return new DefaultMessageConverterHelper();
}
}
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
org.springframework.cloud.stream.binder.kafka.config.ExtendedBindingHandlerMappingsProviderConfiguration
org.springframework.cloud.stream.binder.kafka.config.MessageConverterHelperConfiguration
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2024 the original author or authors.
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,163 +18,81 @@

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.cloud.function.context.config.MessageConverterHelper;
import org.springframework.cloud.function.json.JacksonMapper;
import org.springframework.cloud.stream.binder.test.InputDestination;
import org.springframework.cloud.stream.binder.test.OutputDestination;
import org.springframework.cloud.stream.binder.test.TestChannelBinder;
import org.springframework.cloud.stream.binder.test.TestChannelBinderConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.cloud.stream.function.StreamBridge;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.kafka.test.context.EmbeddedKafka;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHandlingException;
import org.springframework.messaging.converter.MessageConversionException;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.util.Assert;

import static org.assertj.core.api.Assertions.assertThat;

/**
*
* @author Soby Chacko
*/
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.NONE, properties = {
"spring.cloud.function.definition=batchConsumer",
"spring.cloud.stream.bindings.batchConsumer-in-0.consumer.batch-mode=true",
"spring.cloud.stream.bindings.batchConsumer-in-0.destination=cfrthp-topic",
"spring.cloud.stream.bindings.batchConsumer-in-0.group=cfrthp-group"
})
@EmbeddedKafka
@DirtiesContext
public class FunctionBatchingConversionTests {

@SuppressWarnings("unchecked")
// @Test
void testBatchHeadersMatchingPayload() {
TestChannelBinderConfiguration.applicationContextRunner(BatchFunctionConfiguration.class)
.withPropertyValues("spring.cloud.stream.function.definition=func",
"spring.cloud.stream.bindings.func-in-0.consumer.batch-mode=true",
"spring.cloud.stream.rabbit.bindings.func-in-0.consumer.enable-batching=true")
.run(context -> {
InputDestination inputDestination = context.getBean(InputDestination.class);
OutputDestination outputDestination = context.getBean(OutputDestination.class);

List<byte[]> payloads = List.of("hello".getBytes(StandardCharsets.UTF_8),
"{\"name\":\"Ricky\"}".getBytes(StandardCharsets.UTF_8),
"{\"name\":\"Julien\"}".getBytes(StandardCharsets.UTF_8),
"{\"name\":\"Bubbles\"}".getBytes(StandardCharsets.UTF_8),
"hello".getBytes(StandardCharsets.UTF_8));
List<Map<String, String>> amqpBatchHeaders = new ArrayList<>();
for (int i = 0; i < 5; i++) {
Map<String, String> batchHeaders = new LinkedHashMap<>();
batchHeaders.put("amqp_receivedDeliveryMode", "PERSISTENT");
batchHeaders.put("index", String.valueOf(i));
amqpBatchHeaders.add(batchHeaders);
}

var message = MessageBuilder.withPayload(payloads)
.setHeader("amqp_batchedHeaders", amqpBatchHeaders)
.setHeader("deliveryAttempt", new AtomicInteger(1)).build();
inputDestination.send(message);

Message<byte[]> resultMessage = outputDestination.receive();
JacksonMapper mapper = context.getBean(JacksonMapper.class);
List<?> resultPayloads = mapper.fromJson(resultMessage.getPayload(), List.class);
assertThat(resultPayloads).hasSize(3);

List<Map<String, String>> amqpBatchedHeaders = (List<Map<String, String>>) resultMessage
.getHeaders().get("amqp_batchedHeaders");
assertThat(amqpBatchedHeaders).hasSize(resultPayloads.size());
assertThat(amqpBatchedHeaders.get(0).get("index")).isEqualTo("1");
assertThat(amqpBatchedHeaders.get(1).get("index")).isEqualTo("2");
assertThat(amqpBatchedHeaders.get(2).get("index")).isEqualTo("3");

context.stop();
});
}
@Autowired
private StreamBridge streamBridge;

// @Test
void testBatchHeadersForcingFatalFailureOnConversiionException() {
TestChannelBinderConfiguration
.applicationContextRunner(BatchFunctionConfigurationWithAdditionalConversionHelper.class)
.withPropertyValues("spring.cloud.stream.function.definition=func",
"spring.cloud.stream.bindings.func-in-0.consumer.batch-mode=true",
"spring.cloud.stream.bindings.func-in-0.consumer.max-attempts=1",
"spring.cloud.stream.rabbit.bindings.func-in-0.consumer.enable-batching=true")
.run(context -> {
InputDestination inputDestination = context.getBean(InputDestination.class);

List<byte[]> payloads = List.of("hello".getBytes(StandardCharsets.UTF_8),
"{\"name\":\"Ricky\"}".getBytes(StandardCharsets.UTF_8),
"{\"name\":\"Julien\"}".getBytes(StandardCharsets.UTF_8),
"{\"name\":\"Bubbles\"}".getBytes(StandardCharsets.UTF_8),
"hello".getBytes(StandardCharsets.UTF_8));
List<Map<String, String>> amqpBatchHeaders = new ArrayList<>();
for (int i = 0; i < 5; i++) {
Map<String, String> batchHeaders = new LinkedHashMap<>();
batchHeaders.put("amqp_receivedDeliveryMode", "PERSISTENT");
batchHeaders.put("index", String.valueOf(i));
amqpBatchHeaders.add(batchHeaders);
}

var message = MessageBuilder.withPayload(payloads)
.setHeader("amqp_batchedHeaders", amqpBatchHeaders)
.setHeader("deliveryAttempt", new AtomicInteger(1)).build();
inputDestination.send(message);
TestChannelBinder binder = context.getBean(TestChannelBinder.class);
assertThat(binder.getLastError().getPayload()).isInstanceOf(MessageHandlingException.class);
MessageHandlingException exception = (MessageHandlingException) binder.getLastError().getPayload();
assertThat(exception.getCause()).isInstanceOf(MessageConversionException.class);

context.stop();
});
}
static CountDownLatch latch = new CountDownLatch(3);

@Configuration
@EnableAutoConfiguration
public static class BatchFunctionConfiguration {
@Bean
public Function<Message<List<Person>>, Message<List<Person>>> func() {
return x -> {
return x;
};
}
static List<Person> persons = new ArrayList<>();

@Test
void conversionFailuresRemoveTheHeadersProperly() throws Exception {
streamBridge.send("cfrthp-topic", "hello".getBytes(StandardCharsets.UTF_8));
streamBridge.send("cfrthp-topic", "hello".getBytes(StandardCharsets.UTF_8));
streamBridge.send("cfrthp-topic", "{\"name\":\"Ricky\"}".getBytes(StandardCharsets.UTF_8));
streamBridge.send("cfrthp-topic", "{\"name\":\"Julian\"}".getBytes(StandardCharsets.UTF_8));
streamBridge.send("cfrthp-topic", "{\"name\":\"Bubbles\"}".getBytes(StandardCharsets.UTF_8));

Assert.isTrue(latch.await(10, TimeUnit.SECONDS), "Failed to receive message");

assertThat(persons.size()).isEqualTo(3);
assertThat(persons.get(0).toString().contains("Ricky")).isTrue();
assertThat(persons.get(1).toString().contains("Julian")).isTrue();
assertThat(persons.get(2).toString().contains("Bubbles")).isTrue();
}

@Configuration
@EnableAutoConfiguration
public static class BatchFunctionConfigurationWithAdditionalConversionHelper {
@Configuration
public static class Config {

@Bean
public MessageConverterHelper helper() {
return new MessageConverterHelper() {
public boolean shouldFailIfCantConvert(Message<?> message) {
return true;
Consumer<Message<List<Person>>> batchConsumer() {
return message -> {
if (!message.getPayload().isEmpty()) {
message.getPayload().forEach(c -> {
persons.add(c);
latch.countDown();
});
}
};
}

@Bean
public Function<Message<List<Person>>, Message<List<Person>>> func() {
return x -> {
return x;
};
}
}

static class Person {

private String name;

public String getName() {
return name;
}

public void setName(String name) {
this.name = name;
}

public String toString() {
return "name: " + name;
}
record Person(String name) {
}

}

0 comments on commit fac9eb1

Please sign in to comment.