diff --git a/spring-kafka/src/main/java/org/springframework/kafka/listener/CommonDelegatingErrorHandler.java b/spring-kafka/src/main/java/org/springframework/kafka/listener/CommonDelegatingErrorHandler.java index 9d0deac24d..d510701f35 100644 --- a/spring-kafka/src/main/java/org/springframework/kafka/listener/CommonDelegatingErrorHandler.java +++ b/spring-kafka/src/main/java/org/springframework/kafka/listener/CommonDelegatingErrorHandler.java @@ -70,15 +70,7 @@ public void setErrorHandlers(Map, CommonErrorHandler> Assert.notNull(delegates, "'delegates' cannot be null"); this.delegates.clear(); this.delegates.putAll(delegates); - checkDelegates(); - updateClassifier(delegates); - } - - private void updateClassifier(Map, CommonErrorHandler> delegates) { - Map, Boolean> classifications = delegates.keySet().stream() - .map(commonErrorHandler -> Map.entry(commonErrorHandler, true)) - .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); - this.classifier = new BinaryExceptionClassifier(classifications); + checkDelegatesAndUpdateClassifier(this.delegates); } /** @@ -119,12 +111,17 @@ public void setAckAfterHandle(boolean ack) { * @param handler the handler. */ public void addDelegate(Class throwable, CommonErrorHandler handler) { - this.delegates.put(throwable, handler); - checkDelegates(); + Map, CommonErrorHandler> delegatesToCheck = new LinkedHashMap<>(this.delegates); + delegatesToCheck.put(throwable, handler); + checkDelegatesAndUpdateClassifier(delegatesToCheck); + this.delegates.clear(); + this.delegates.putAll(delegatesToCheck); } @SuppressWarnings("deprecation") - private void checkDelegates() { + private void checkDelegatesAndUpdateClassifier(Map, + CommonErrorHandler> delegatesToCheck) { + boolean ackAfterHandle = this.defaultErrorHandler.isAckAfterHandle(); boolean seeksAfterHandling = this.defaultErrorHandler.seeksAfterHandling(); this.delegates.values().forEach(handler -> { @@ -133,6 +130,14 @@ private void checkDelegates() { Assert.isTrue(seeksAfterHandling == handler.seeksAfterHandling(), "All delegates must return the same value when calling 'seeksAfterHandling()'"); }); + updateClassifier(delegatesToCheck); + } + + private void updateClassifier(Map, CommonErrorHandler> delegates) { + Map, Boolean> classifications = delegates.keySet().stream() + .map(commonErrorHandler -> Map.entry(commonErrorHandler, true)) + .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); + this.classifier = new BinaryExceptionClassifier(classifications); } @Override diff --git a/spring-kafka/src/test/java/org/springframework/kafka/listener/CommonDelegatingErrorHandlerTests.java b/spring-kafka/src/test/java/org/springframework/kafka/listener/CommonDelegatingErrorHandlerTests.java index 38db6405ba..1a25253a41 100644 --- a/spring-kafka/src/test/java/org/springframework/kafka/listener/CommonDelegatingErrorHandlerTests.java +++ b/spring-kafka/src/test/java/org/springframework/kafka/listener/CommonDelegatingErrorHandlerTests.java @@ -16,6 +16,7 @@ package org.springframework.kafka.listener; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -31,6 +32,7 @@ import org.springframework.kafka.KafkaException; import org.springframework.kafka.core.KafkaProducerException; +import org.springframework.kafka.test.utils.KafkaTestUtils; /** * Tests for {@link CommonDelegatingErrorHandler}. @@ -134,7 +136,7 @@ void testDelegateForThrowableCauseIsAppliedWhenCauseTraversingIsEnabled() { } @Test - @SuppressWarnings("ConstantConditions") + @SuppressWarnings({ "ConstantConditions", "unchecked" }) void testDelegateForClassifiableThrowableCauseIsAppliedWhenCauseTraversingIsEnabled() { var defaultHandler = mock(CommonErrorHandler.class); @@ -147,6 +149,10 @@ void testDelegateForClassifiableThrowableCauseIsAppliedWhenCauseTraversingIsEnab delegatingErrorHandler.setErrorHandlers(Map.of( KafkaException.class, directCauseErrorHandler )); + delegatingErrorHandler.addDelegate(IllegalStateException.class, mock(CommonErrorHandler.class)); + assertThat(KafkaTestUtils.getPropertyValue(delegatingErrorHandler, "classifier.classified", Map.class).keySet()) + .contains(IllegalStateException.class); + delegatingErrorHandler.handleRemaining(exc, Collections.emptyList(), mock(Consumer.class), mock(MessageListenerContainer.class));