diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/adapter/MessagingMessageListenerAdapter.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/adapter/MessagingMessageListenerAdapter.java index f1cc77b3c6..110ee6d23c 100644 --- a/spring-kafka/src/main/java/org/springframework/kafka/listener/adapter/MessagingMessageListenerAdapter.java +++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/adapter/MessagingMessageListenerAdapter.java @@ -80,6 +80,7 @@ * @author Gary Russell * @author Artem Bilan * @author Venil Noronha + * @author Nathan Xu */ public abstract class MessagingMessageListenerAdapter implements ConsumerSeekAware { @@ -466,7 +467,13 @@ private String evaluateTopic(Object request, Object source, Object result, @Null * @since 2.1.3 */ @SuppressWarnings("unchecked") - protected void sendResponse(Object result, String topic, @Nullable Object source, boolean returnTypeMessage) { + protected void sendResponse(Object result, @Nullable String topic, @Nullable Object source, boolean returnTypeMessage) { + if (topic == null && source instanceof Message) { + byte[] replyTopicBytes = getReplyTopic((Message) source); + if (replyTopicBytes != null) { + topic = new String(replyTopicBytes, StandardCharsets.UTF_8); + } + } if (!returnTypeMessage && topic == null) { this.logger.debug(() -> "No replyTopic to handle the reply: " + result); } @@ -482,14 +489,14 @@ else if (result instanceof Message) { iterableOfMessages = iterator.next() instanceof Message; } if (iterableOfMessages || this.splitIterables) { - ((Iterable) result).forEach(v -> { + for (V v : (Iterable) result) { if (v instanceof Message) { - this.replyTemplate.send((Message) v); + this.replyTemplate.send(checkHeaders(v, topic, source)); } else { this.replyTemplate.send(topic, v); } - }); + } } else { sendSingleResult(result, topic, source); @@ -506,7 +513,8 @@ private Message checkHeaders(Object result, String topic, @Nullable Object so MessageHeaders headers = reply.getHeaders(); boolean needsTopic = headers.get(KafkaHeaders.TOPIC) == null; boolean sourceIsMessage = source instanceof Message; - boolean needsCorrelation = headers.get(this.correlationHeaderName) == null && sourceIsMessage; + boolean needsCorrelation = headers.get(this.correlationHeaderName) == null && sourceIsMessage + && getCorrelationId((Message) source) != null; boolean needsPartition = headers.get(KafkaHeaders.PARTITION) == null && sourceIsMessage && getReplyPartition((Message) source) != null; if (needsTopic || needsCorrelation || needsPartition) { @@ -514,11 +522,10 @@ private Message checkHeaders(Object result, String topic, @Nullable Object so if (needsTopic) { builder.setHeader(KafkaHeaders.TOPIC, topic); } - if (needsCorrelation && sourceIsMessage) { - builder.setHeader(this.correlationHeaderName, - ((Message) source).getHeaders().get(this.correlationHeaderName)); + if (needsCorrelation) { + setCorrelationId(builder, (Message) source); } - if (sourceIsMessage && reply.getHeaders().get(KafkaHeaders.REPLY_PARTITION) == null) { + if (needsPartition) { setPartition(builder, (Message) source); } reply = builder.build(); @@ -571,6 +578,30 @@ private void sendReplyForMessageSource(Object result, String topic, Object sourc this.replyTemplate.send(builder.build()); } + private void setTopic(MessageBuilder builder, Message source) { + byte[] topicBytes = getReplyTopic(source); + if (topicBytes != null) { + builder.setHeader(KafkaHeaders.TOPIC, new String(topicBytes, StandardCharsets.UTF_8)); + } + } + + @Nullable + private byte[] getReplyTopic(Message source) { + return source.getHeaders().get(KafkaHeaders.REPLY_TOPIC, byte[].class); + } + + private void setCorrelationId(MessageBuilder builder, Message source) { + byte[] correlationIdBytes = getCorrelationId(source); + if (correlationIdBytes != null) { + builder.setHeader(this.correlationHeaderName, correlationIdBytes); + } + } + + @Nullable + private byte[] getCorrelationId(Message source) { + return source.getHeaders().get(this.correlationHeaderName, byte[].class); + } + private void setPartition(MessageBuilder builder, Message source) { byte[] partitionBytes = getReplyPartition(source); if (partitionBytes != null) {