Skip to content

Commit

Permalink
add KafkaProducer post-processor
Browse files Browse the repository at this point in the history
  • Loading branch information
sszp committed Sep 5, 2024
1 parent 6b08a17 commit e376f5d
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.transferwise.kafka.tkms.api.ITkmsMessageInterceptors;
import com.transferwise.kafka.tkms.api.TkmsShardPartition;
import com.transferwise.kafka.tkms.config.ITkmsDaoProvider;
import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerPostProcessor;
import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerProvider;
import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerProvider.UseCase;
import com.transferwise.kafka.tkms.config.TkmsProperties;
Expand Down Expand Up @@ -43,6 +44,7 @@
import org.apache.commons.lang3.mutable.MutableLong;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.common.errors.InterruptException;
Expand Down Expand Up @@ -79,6 +81,8 @@ public class TkmsStorageToKafkaProxy implements GracefulShutdownStrategy, ITkmsS
@Autowired
private ITkmsMessageInterceptors messageIntereceptors;
@Autowired
private ITkmsKafkaProducerPostProcessor tkmsKafkaProducerPostProcessor;
@Autowired
private SharedReentrantLockBuilderFactory lockBuilderFactory;
@Autowired
private ITkmsInterrupterService tkmsInterrupterService;
Expand Down Expand Up @@ -166,7 +170,7 @@ private void poll(Control control, TkmsShardPartition shardPartition) {
}
}

private void poll0(Control control, TkmsShardPartition shardPartition, KafkaProducer<String, byte[]> kafkaProducer) {
private void poll0(Control control, TkmsShardPartition shardPartition, Producer<String, byte[]> kafkaProducer) {

int pollerBatchSize = properties.getPollerBatchSize(shardPartition.getShard());
long startTimeMs = System.currentTimeMillis();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.transferwise.kafka.tkms.config;

import java.util.function.Function;
import org.apache.kafka.clients.producer.Producer;

/**
 * Extension point for wrapping or replacing the Kafka {@link Producer} instances the library creates,
 * e.g. to instrument outgoing messages.
 *
 * <p>Implementations registered via {@code ITkmsKafkaProducerProvider#addPostProcessor} are applied to
 * every newly created producer, in registration order; the producer returned by {@code apply} is the
 * one the library will use.
 */
@FunctionalInterface
public interface ITkmsKafkaProducerPostProcessor extends Function<Producer<String, byte[]>, Producer<String, byte[]>> {
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package com.transferwise.kafka.tkms.config;

import com.transferwise.kafka.tkms.api.TkmsShardPartition;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.Producer;

public interface ITkmsKafkaProducerProvider {

KafkaProducer<String, byte[]> getKafkaProducer(TkmsShardPartition tkmsShardPartition, UseCase useCase);
Producer<String, byte[]> getKafkaProducer(TkmsShardPartition tkmsShardPartition, UseCase useCase);

KafkaProducer<String, byte[]> getKafkaProducerForTopicValidation(TkmsShardPartition shardPartition);
Producer<String, byte[]> getKafkaProducerForTopicValidation(TkmsShardPartition shardPartition);

default void addPostProcessor(ITkmsKafkaProducerPostProcessor postProcessor) {
}

default void removePostProcessors() {
}

void closeKafkaProducer(TkmsShardPartition tkmsShardPartition, UseCase useCase);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.binder.kafka.KafkaClientMetrics;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -16,10 +18,12 @@
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.common.serialization.ByteArraySerializer;
import org.apache.kafka.common.serialization.StringSerializer;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.Assert;

@Slf4j
public class TkmsKafkaProducerProvider implements ITkmsKafkaProducerProvider, GracefulShutdownStrategy {
Expand All @@ -39,8 +43,21 @@ public class TkmsKafkaProducerProvider implements ITkmsKafkaProducerProvider, Gr

private Map<Pair<TkmsShardPartition, UseCase>, ProducerEntry> producers = new ConcurrentHashMap<>();

private List<ITkmsKafkaProducerPostProcessor> postProcessors = new ArrayList<>();

/**
 * Registers a post-processor that will be applied to every producer created by this provider.
 *
 * <p>NOTE(review): {@code postProcessors} is a plain {@code ArrayList}; registration is not
 * synchronized with lazy producer creation — assumes all registrations happen during application
 * context startup, before any producer is requested. TODO confirm.
 */
@Override
public void addPostProcessor(ITkmsKafkaProducerPostProcessor postProcessor) {
  // Fail fast so a broken registration is caught at startup, not at first send.
  Assert.notNull(postProcessor, "'postProcessor' cannot be null");
  this.postProcessors.add(postProcessor);
}

/**
 * Removes all registered post-processors. Producers already created (and possibly wrapped) are
 * not affected — only producers created after this call skip post-processing.
 */
@Override
public void removePostProcessors() {
  this.postProcessors.clear();
}

@Override
public KafkaProducer<String, byte[]> getKafkaProducer(TkmsShardPartition shardPartition, UseCase useCase) {
public Producer<String, byte[]> getKafkaProducer(TkmsShardPartition shardPartition, UseCase useCase) {
return producers.computeIfAbsent(Pair.of(shardPartition, useCase), key -> {
Map<String, Object> configs = new HashMap<>();

Expand Down Expand Up @@ -84,16 +101,24 @@ public KafkaProducer<String, byte[]> getKafkaProducer(TkmsShardPartition shardPa
}
}

final var producer = new KafkaProducer<String, byte[]>(configs);
final var producer = getKafkaProducer(configs);
final var kafkaClientMetrics = new KafkaClientMetrics(producer);
kafkaClientMetrics.bindTo(meterRegistry);

return new ProducerEntry().setProducer(producer).setKafkaClientMetric(kafkaClientMetrics);
}).getProducer();
}

/**
 * Creates the raw {@code KafkaProducer} and then lets every registered post-processor wrap or
 * replace it, in registration order. The final result is what callers receive.
 */
private Producer<String, byte[]> getKafkaProducer(Map<String, Object> configs) {
  Producer<String, byte[]> result = new KafkaProducer<>(configs);
  for (ITkmsKafkaProducerPostProcessor postProcessor : postProcessors) {
    result = postProcessor.apply(result);
  }
  return result;
}

@Override
public KafkaProducer<String, byte[]> getKafkaProducerForTopicValidation(TkmsShardPartition shardPartition) {
public Producer<String, byte[]> getKafkaProducerForTopicValidation(TkmsShardPartition shardPartition) {
return getKafkaProducer(TkmsShardPartition.of(shardPartition.getShard(), 0), UseCase.TOPIC_VALIDATION);
}

Expand Down Expand Up @@ -139,7 +164,7 @@ public boolean canShutdown() {
@Accessors(chain = true)
protected static class ProducerEntry {

  // The (possibly post-processed) producer instance handed out to callers.
  private Producer<String, byte[]> producer;

  // Micrometer metrics binding created for this producer.
  private KafkaClientMetrics kafkaClientMetric;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.transferwise.kafka.tkms;

import static com.transferwise.kafka.tkms.test.TestKafkaProducerPostProcessor.TEST_MESSAGE;
import static org.awaitility.Awaitility.await;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.transferwise.common.baseutils.transactionsmanagement.ITransactionsHelper;
import com.transferwise.kafka.tkms.api.ITransactionalKafkaMessageSender;
import com.transferwise.kafka.tkms.api.TkmsMessage;
import com.transferwise.kafka.tkms.test.BaseIntTest;
import com.transferwise.kafka.tkms.test.ITkmsSentMessagesCollector;
import com.transferwise.kafka.tkms.test.TestProperties;
import java.nio.charset.StandardCharsets;
import java.util.stream.StreamSupport;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.springframework.beans.factory.annotation.Autowired;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class MessagePostProcessingTest extends BaseIntTest {

  @Autowired
  private TransactionalKafkaMessageSender transactionalKafkaMessageSender;

  @Autowired
  private TestProperties testProperties;

  @Autowired
  private ITransactionsHelper transactionsHelper;

  @BeforeEach
  void setupTest() {
    tkmsSentMessagesCollector.clear();
  }

  @AfterEach
  void cleanupTest() {
    tkmsSentMessagesCollector.clear();
  }

  @Test
  void messagesAreInstrumentedWithProducerPostProcessor() {
    // Payload the test post-processor recognizes and tags with a marker header.
    byte[] payload = TEST_MESSAGE;
    String topic = testProperties.getTestTopic();

    transactionsHelper.withTransaction().run(() ->
        transactionalKafkaMessageSender.sendMessages(new ITransactionalKafkaMessageSender.SendMessagesRequest()
            .addTkmsMessage(new TkmsMessage().setTopic(topic).setKey("1").setValue(payload))
            .addTkmsMessage(new TkmsMessage().setTopic(topic).setKey("2").setValue(payload))
        ));

    await().until(() -> tkmsSentMessagesCollector.getSentMessages(topic).size() == 2);

    var sentMessages = tkmsSentMessagesCollector.getSentMessages(topic);
    assertEquals(2, sentMessages.size());
    // Both records must carry the header added by TestKafkaProducerPostProcessor's proxy.
    for (var sentMessage : sentMessages) {
      checkForHeader(sentMessage, "wrapTest", "wrapped");
    }
  }

  // Asserts that the sent record carries a header with the given key and UTF-8 value.
  private void checkForHeader(ITkmsSentMessagesCollector.SentMessage sentMessage, String key, String value) {
    boolean found = false;
    for (var header : sentMessage.getProducerRecord().headers()) {
      if (header.key().equals(key) && value.equals(new String(header.value(), StandardCharsets.UTF_8))) {
        found = true;
        break;
      }
    }
    assertTrue(found);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerProvider.UseCase;
import com.transferwise.kafka.tkms.test.BaseIntTest;
import java.lang.reflect.Field;
import org.apache.kafka.clients.producer.KafkaProducer;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -16,13 +18,20 @@ class TkmsKafkaProducerProviderTestServer extends BaseIntTest {
@Autowired
private ITkmsKafkaProducerProvider tkmsKafkaProducerProvider;


@Test
void shardKafkaPropertiesAreApplied() throws Exception {
  // Producers are post-processed into a JDK dynamic proxy (see TestKafkaProducerPostProcessor),
  // so the real KafkaProducer must first be dug out of the proxy's invocation handler.
  Producer<String, byte[]> kafkaProducer = tkmsKafkaProducerProvider.getKafkaProducer(TkmsShardPartition.of(1, 0), UseCase.PROXY);

  InvocationHandler handler = Proxy.getInvocationHandler(kafkaProducer);

  // The handler's "producer" field holds the unwrapped KafkaProducer.
  Field originalProducerField = handler.getClass().getDeclaredField("producer");
  originalProducerField.setAccessible(true);
  Object originalProducer = originalProducerField.get(handler);

  // Reach into KafkaProducer's private config to verify the shard-level override was applied.
  Field producerConfigField = originalProducer.getClass().getDeclaredField("producerConfig");
  producerConfigField.setAccessible(true);
  ProducerConfig producerConfig = (ProducerConfig) producerConfigField.get(originalProducer);

  // Shard-1-specific value from the test configuration.
  assertThat(producerConfig.getLong("linger.ms")).isEqualTo(7L);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.transferwise.kafka.tkms.test;

import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerPostProcessor;
import com.transferwise.kafka.tkms.config.TkmsKafkaProducerProvider;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * Test post-processor that wraps every created Kafka producer into a JDK dynamic proxy which tags
 * outgoing records carrying {@link #TEST_MESSAGE} with a {@code wrapTest=wrapped} header, so tests
 * can verify the proxy sat in the send path.
 *
 * <p>Registers itself with {@link TkmsKafkaProducerProvider} at startup via {@link InitializingBean}.
 */
@Component
public class TestKafkaProducerPostProcessor implements ITkmsKafkaProducerPostProcessor, InitializingBean {

  // Payload sentinel: only records whose value equals this byte array get the marker header.
  public static final byte[] TEST_MESSAGE = "Testing ProducerPostProcessing".getBytes(StandardCharsets.UTF_8);

  @Autowired
  private TkmsKafkaProducerProvider tkmsKafkaProducerProvider;

  @SuppressWarnings("unchecked")
  @Override
  public Producer<String, byte[]> apply(Producer<String, byte[]> producer) {
    // No state is kept on this bean, so it can safely wrap any number of producers.
    return (Producer<String, byte[]>)
        Proxy.newProxyInstance(
            TestKafkaProducerPostProcessor.class.getClassLoader(),
            new Class<?>[] {Producer.class},
            new ProxyInvocationHandler(producer));
  }

  @Override
  public void afterPropertiesSet() throws Exception {
    tkmsKafkaProducerProvider.addPostProcessor(this);
  }

  private static class ProxyInvocationHandler implements InvocationHandler {

    // Field name "producer" is read reflectively by TkmsKafkaProducerProviderTestServer — do not rename.
    private final Producer<String, byte[]> producer;

    public ProxyInvocationHandler(Producer<String, byte[]> producer) {
      this.producer = producer;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
      if ("send".equals(method.getName())
          && method.getParameterCount() >= 1
          && method.getParameterTypes()[0] == ProducerRecord.class) {
        @SuppressWarnings("unchecked")
        ProducerRecord<String, byte[]> record = (ProducerRecord<String, byte[]>) args[0];
        if (Arrays.equals(TEST_MESSAGE, record.value())) {
          // Tag the record so tests can observe that the proxy handled the send.
          record.headers().add("wrapTest", "wrapped".getBytes(StandardCharsets.UTF_8));
        }
        // Covers both Producer.send overloads: send(record) and send(record, callback).
        Callback callback =
            method.getParameterCount() >= 2 && method.getParameterTypes()[1] == Callback.class
                ? (Callback) args[1]
                : null;
        return producer.send(record, callback);
      } else {
        try {
          return method.invoke(producer, args);
        } catch (InvocationTargetException exception) {
          // Unwrap so callers see the producer's original exception type, not the reflection wrapper.
          throw exception.getCause();
        }
      }
    }
  }
}

0 comments on commit e376f5d

Please sign in to comment.