diff --git a/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/TkmsStorageToKafkaProxy.java b/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/TkmsStorageToKafkaProxy.java index fb72a50..380197c 100644 --- a/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/TkmsStorageToKafkaProxy.java +++ b/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/TkmsStorageToKafkaProxy.java @@ -16,6 +16,7 @@ import com.transferwise.kafka.tkms.api.ITkmsMessageInterceptors; import com.transferwise.kafka.tkms.api.TkmsShardPartition; import com.transferwise.kafka.tkms.config.ITkmsDaoProvider; +import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerPostProcessor; import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerProvider; import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerProvider.UseCase; import com.transferwise.kafka.tkms.config.TkmsProperties; @@ -43,6 +44,7 @@ import org.apache.commons.lang3.mutable.MutableLong; import org.apache.commons.lang3.mutable.MutableObject; import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerRecord; import org.apache.kafka.clients.producer.RecordMetadata; import org.apache.kafka.common.errors.InterruptException; @@ -79,6 +81,8 @@ public class TkmsStorageToKafkaProxy implements GracefulShutdownStrategy, ITkmsS @Autowired private ITkmsMessageInterceptors messageIntereceptors; @Autowired + private ITkmsKafkaProducerPostProcessor tkmsKafkaProducerPostProcessor; + @Autowired private SharedReentrantLockBuilderFactory lockBuilderFactory; @Autowired private ITkmsInterrupterService tkmsInterrupterService; @@ -166,7 +170,7 @@ private void poll(Control control, TkmsShardPartition shardPartition) { } } - private void poll0(Control control, TkmsShardPartition shardPartition, KafkaProducer kafkaProducer) { + private void poll0(Control control, TkmsShardPartition shardPartition, Producer kafkaProducer) { int pollerBatchSize = properties.getPollerBatchSize(shardPartition.getShard()); long startTimeMs = System.currentTimeMillis(); diff --git a/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/ITkmsKafkaProducerPostProcessor.java b/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/ITkmsKafkaProducerPostProcessor.java new file mode 100644 index 0000000..0b4b0ed --- /dev/null +++ b/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/ITkmsKafkaProducerPostProcessor.java @@ -0,0 +1,7 @@ +package com.transferwise.kafka.tkms.config; + +import java.util.function.Function; +import org.apache.kafka.clients.producer.Producer; + +public interface ITkmsKafkaProducerPostProcessor extends Function, Producer> { +} diff --git a/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/ITkmsKafkaProducerProvider.java b/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/ITkmsKafkaProducerProvider.java index 3473305..3402902 100644 --- a/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/ITkmsKafkaProducerProvider.java +++ b/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/ITkmsKafkaProducerProvider.java @@ -1,13 +1,19 @@ package com.transferwise.kafka.tkms.config; import com.transferwise.kafka.tkms.api.TkmsShardPartition; -import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; public interface ITkmsKafkaProducerProvider { - KafkaProducer getKafkaProducer(TkmsShardPartition tkmsShardPartition, UseCase useCase); + Producer getKafkaProducer(TkmsShardPartition tkmsShardPartition, UseCase useCase); - KafkaProducer getKafkaProducerForTopicValidation(TkmsShardPartition shardPartition); + Producer getKafkaProducerForTopicValidation(TkmsShardPartition shardPartition); + + default void addPostProcessor(ITkmsKafkaProducerPostProcessor postProcessor) { + } + + default void removePostProcessors() { + } void closeKafkaProducer(TkmsShardPartition tkmsShardPartition, UseCase useCase); diff --git a/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/TkmsKafkaProducerProvider.java b/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/TkmsKafkaProducerProvider.java index 21c98c4..193e4a2 100644 --- a/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/TkmsKafkaProducerProvider.java +++ b/tw-tkms-starter/src/main/java/com/transferwise/kafka/tkms/config/TkmsKafkaProducerProvider.java @@ -6,7 +6,9 @@ import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.binder.kafka.KafkaClientMetrics; import java.time.Duration; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -16,10 +18,12 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.common.serialization.ByteArraySerializer; import org.apache.kafka.common.serialization.StringSerializer; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.util.Assert; @Slf4j public class TkmsKafkaProducerProvider implements ITkmsKafkaProducerProvider, GracefulShutdownStrategy { @@ -39,8 +43,21 @@ public class TkmsKafkaProducerProvider implements ITkmsKafkaProducerProvider, Gr private Map, ProducerEntry> producers = new ConcurrentHashMap<>(); + private List postProcessors = new ArrayList<>(); + + @Override + public void addPostProcessor(ITkmsKafkaProducerPostProcessor postProcessor) { + Assert.notNull(postProcessor, "'postProcessor' cannot be null"); + this.postProcessors.add(postProcessor); + } + + @Override + public void removePostProcessors() { + this.postProcessors.clear(); + } + @Override - public KafkaProducer getKafkaProducer(TkmsShardPartition shardPartition, UseCase useCase) { + public Producer getKafkaProducer(TkmsShardPartition shardPartition, UseCase useCase) { return producers.computeIfAbsent(Pair.of(shardPartition, useCase), key -> { Map configs = new HashMap<>(); @@ -84,7 +101,7 @@ public KafkaProducer getKafkaProducer(TkmsShardPartition shardPa } } - final var producer = new KafkaProducer(configs); + final var producer = getKafkaProducer(configs); final var kafkaClientMetrics = new KafkaClientMetrics(producer); kafkaClientMetrics.bindTo(meterRegistry); @@ -92,8 +109,16 @@ public KafkaProducer getKafkaProducer(TkmsShardPartition shardPa }).getProducer(); } + private Producer getKafkaProducer(Map configs) { + Producer producer = new KafkaProducer<>(configs); + for (ITkmsKafkaProducerPostProcessor pp : this.postProcessors) { + producer = pp.apply(producer); + } + return producer; + } + @Override - public KafkaProducer getKafkaProducerForTopicValidation(TkmsShardPartition shardPartition) { + public Producer getKafkaProducerForTopicValidation(TkmsShardPartition shardPartition) { return getKafkaProducer(TkmsShardPartition.of(shardPartition.getShard(), 0), UseCase.TOPIC_VALIDATION); } @@ -139,7 +164,7 @@ public boolean canShutdown() { @Accessors(chain = true) protected static class ProducerEntry { - private KafkaProducer producer; + private Producer producer; private KafkaClientMetrics kafkaClientMetric; } diff --git a/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/MessagePostProcessingTest.java b/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/MessagePostProcessingTest.java new file mode 100644 index 0000000..756c7b2 --- /dev/null +++ b/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/MessagePostProcessingTest.java @@ -0,0 +1,70 @@ +package com.transferwise.kafka.tkms; + +import static com.transferwise.kafka.tkms.test.TestKafkaProducerPostProcessor.TEST_MESSAGE; +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.transferwise.common.baseutils.transactionsmanagement.ITransactionsHelper; +import com.transferwise.kafka.tkms.api.ITransactionalKafkaMessageSender; +import com.transferwise.kafka.tkms.api.TkmsMessage; +import com.transferwise.kafka.tkms.test.BaseIntTest; +import com.transferwise.kafka.tkms.test.ITkmsSentMessagesCollector; +import com.transferwise.kafka.tkms.test.TestProperties; +import java.nio.charset.StandardCharsets; +import java.util.stream.StreamSupport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.springframework.beans.factory.annotation.Autowired; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class MessagePostProcessingTest extends BaseIntTest { + + @Autowired + private TransactionalKafkaMessageSender transactionalKafkaMessageSender; + + @Autowired + private TestProperties testProperties; + + @Autowired + private ITransactionsHelper transactionsHelper; + + @BeforeEach + void setupTest() { + tkmsSentMessagesCollector.clear(); + } + + @AfterEach + void cleanupTest() { + tkmsSentMessagesCollector.clear(); + } + + @Test + void messagesAreInstrumentedWithProducerPostProcessor() { + byte[] someValue = TEST_MESSAGE; + + String topic = testProperties.getTestTopic(); + + transactionsHelper.withTransaction().run(() -> + transactionalKafkaMessageSender.sendMessages(new ITransactionalKafkaMessageSender.SendMessagesRequest() + .addTkmsMessage(new TkmsMessage().setTopic(topic).setKey("1").setValue(someValue)) + .addTkmsMessage(new TkmsMessage().setTopic(topic).setKey("2").setValue(someValue)) + )); + + await().until(() -> tkmsSentMessagesCollector.getSentMessages(topic).size() == 2); + var messages = tkmsSentMessagesCollector.getSentMessages(topic); + + assertEquals(2, messages.size()); + checkForHeader(messages.get(0), "wrapTest", "wrapped"); + checkForHeader(messages.get(1), "wrapTest", "wrapped"); + } + + private void checkForHeader(ITkmsSentMessagesCollector.SentMessage sentMessage, String key, String value) { + assertTrue( + StreamSupport.stream(sentMessage.getProducerRecord().headers().spliterator(), false) + .anyMatch(h -> h.key().equals(key) && value.equals(new String(h.value(), StandardCharsets.UTF_8))) + ); + } +} diff --git a/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/config/TkmsKafkaProducerProviderTestServer.java b/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/config/TkmsKafkaProducerProviderTestServer.java index cdbf3b2..ca20b78 100644 --- a/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/config/TkmsKafkaProducerProviderTestServer.java +++ b/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/config/TkmsKafkaProducerProviderTestServer.java @@ -6,7 +6,9 @@ import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerProvider.UseCase; import com.transferwise.kafka.tkms.test.BaseIntTest; import java.lang.reflect.Field; -import org.apache.kafka.clients.producer.KafkaProducer; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Proxy; +import org.apache.kafka.clients.producer.Producer; import org.apache.kafka.clients.producer.ProducerConfig; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; @@ -16,13 +18,20 @@ class TkmsKafkaProducerProviderTestServer extends BaseIntTest { @Autowired private ITkmsKafkaProducerProvider tkmsKafkaProducerProvider; + @Test void shardKafkaPropertiesAreApplied() throws Exception { - KafkaProducer kafkaProducer = tkmsKafkaProducerProvider.getKafkaProducer(TkmsShardPartition.of(1, 0), UseCase.PROXY); + Producer kafkaProducer = tkmsKafkaProducerProvider.getKafkaProducer(TkmsShardPartition.of(1, 0), UseCase.PROXY); + + InvocationHandler handler = Proxy.getInvocationHandler(kafkaProducer); + + Field originalProducerFiled = handler.getClass().getDeclaredField("producer"); + originalProducerFiled.setAccessible(true); + Object originalProducer = originalProducerFiled.get(handler); - Field producerConfigField = kafkaProducer.getClass().getDeclaredField("producerConfig"); + Field producerConfigField = originalProducer.getClass().getDeclaredField("producerConfig"); producerConfigField.setAccessible(true); - ProducerConfig producerConfig = (ProducerConfig) producerConfigField.get(kafkaProducer); + ProducerConfig producerConfig = (ProducerConfig) producerConfigField.get(originalProducer); assertThat(producerConfig.getLong("linger.ms")).isEqualTo(7L); } diff --git a/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/test/TestKafkaProducerPostProcessor.java b/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/test/TestKafkaProducerPostProcessor.java new file mode 100644 index 0000000..6288500 --- /dev/null +++ b/tw-tkms-starter/src/test/java/com/transferwise/kafka/tkms/test/TestKafkaProducerPostProcessor.java @@ -0,0 +1,76 @@ +package com.transferwise.kafka.tkms.test; + +import com.transferwise.kafka.tkms.config.ITkmsKafkaProducerPostProcessor; +import com.transferwise.kafka.tkms.config.TkmsKafkaProducerProvider; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +@Component +public class TestKafkaProducerPostProcessor implements ITkmsKafkaProducerPostProcessor, InitializingBean { + + public static final byte[] TEST_MESSAGE = "Testing ProducerPostProcessing".getBytes(StandardCharsets.UTF_8); + + private ProxyInvocationHandler handler; + + @Autowired + TkmsKafkaProducerProvider tkmsKafkaProducerProvider; + + @SuppressWarnings("unchecked") + @Override + public Producer apply(Producer producer) { + handler = new ProxyInvocationHandler(producer); + return (Producer) + Proxy.newProxyInstance( + TestKafkaProducerPostProcessor.class.getClassLoader(), + new Class[] {Producer.class}, + handler); + } + + @Override + public void afterPropertiesSet() throws Exception { + tkmsKafkaProducerProvider.addPostProcessor(this); + } + + private static class ProxyInvocationHandler implements InvocationHandler { + + private final Producer producer; + + public ProxyInvocationHandler(Producer producer) { + this.producer = producer; + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + if ("send".equals(method.getName()) + && method.getParameterCount() >= 1 + && method.getParameterTypes()[0] == ProducerRecord.class) { + ProducerRecord record = (ProducerRecord) args[0]; + if (Arrays.equals(TEST_MESSAGE, record.value())) { + record.headers().add("wrapTest", "wrapped".getBytes(StandardCharsets.UTF_8)); + } + Callback callback = + method.getParameterCount() >= 2 + && method.getParameterTypes()[1] == Callback.class + ? (Callback) args[1] + : null; + return producer.send(record, callback); + } else { + try { + return method.invoke(producer, args); + } catch (InvocationTargetException exception) { + throw exception.getCause(); + } + } + } + } +} \ No newline at end of file