diff --git a/core/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/function/FunctionPostProcessingTests.java b/core/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/function/FunctionPostProcessingTests.java index 24bcc3d0b..621251304 100644 --- a/core/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/function/FunctionPostProcessingTests.java +++ b/core/spring-cloud-stream-integration-tests/src/test/java/org/springframework/cloud/stream/function/FunctionPostProcessingTests.java @@ -17,6 +17,7 @@ package org.springframework.cloud.stream.function; import java.util.function.Function; +import java.util.function.Supplier; import org.junit.jupiter.api.Test; @@ -29,8 +30,10 @@ import org.springframework.cloud.stream.binder.test.TestChannelBinderConfiguration; import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; import org.springframework.integration.support.MessageBuilder; import org.springframework.messaging.Message; +import org.springframework.messaging.support.GenericMessage; import static org.assertj.core.api.Assertions.assertThat; @@ -76,6 +79,22 @@ void successfulPostProcessingOfSingleFunction() { } } + @Test + void successfulPostProcessingOfSupplierFunctionCompposition() throws Exception { + System.clearProperty("spring.cloud.function.definition"); + try (ConfigurableApplicationContext context = new SpringApplicationBuilder( + TestChannelBinderConfiguration.getCompleteConfiguration(SupplierPostProcessingTestConfiguration.class)) + .web(WebApplicationType.NONE).run("--spring.jmx.enabled=false", + "--spring.cloud.function.definition=hello|uppercase", + "--spring.cloud.stream.bindings.hellouppercase-out-0.producer.poller.fixed-delay=100")) { + Thread.sleep(1000); + OutputDestination outputDestination = context.getBean(OutputDestination.class); + + assertThat(outputDestination.receive(5000, "hellouppercase-out-0").getPayload()).isEqualTo("HELLO".getBytes()); + assertThat(context.getBean(SupplierPostProcessingTestConfiguration.class).postProcessed).isTrue(); + } + } + @Test void noPostProcessingOnError() { System.clearProperty("spring.cloud.function.definition"); @@ -207,6 +226,31 @@ public Function reverse() { } } + @EnableAutoConfiguration + @Configuration + public static class SupplierPostProcessingTestConfiguration { + + public static boolean postProcessed; + + @Bean + public Supplier> hello() { + return () -> new GenericMessage<>("hello"); + } + + @Bean + public Function uppercase() { + return new PostProcessingFunction() { + public String apply(String input) { + return input.toUpperCase(); + } + + public void postProcess(Message result) { + postProcessed = true; + } + }; + } + } + private static class SingleFunctionPostProcessingFunction implements PostProcessingFunction { private boolean success; diff --git a/core/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/function/FunctionConfiguration.java b/core/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/function/FunctionConfiguration.java index c051e777b..b6d80bc2f 100644 --- a/core/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/function/FunctionConfiguration.java +++ b/core/spring-cloud-stream/src/main/java/org/springframework/cloud/stream/function/FunctionConfiguration.java @@ -107,6 +107,7 @@ import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.MessagingException; import org.springframework.messaging.SubscribableChannel; +import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.scheduling.TaskScheduler; import org.springframework.scheduling.Trigger; import org.springframework.scheduling.support.CronTrigger; @@ -235,15 +236,22 @@ InitializingBean supplierInitializer(FunctionCatalog functionCatalog, StreamFunc } if (functionWrapper != null) { + FunctionInvocationWrapper postProcessor = functionWrapper; IntegrationFlow integrationFlow = integrationFlowFromProvidedSupplier(new PartitionAwareFunctionWrapper(functionWrapper, context, producerProperties), pollable, context, taskScheduler, producerProperties, outputName) + .intercept(new ChannelInterceptor() { + public void postSend(Message message, MessageChannel channel, boolean sent) { + postProcessor.postProcess(); + } + }) .route(Message.class, message -> { if (message.getHeaders().get("spring.cloud.stream.sendto.destination") != null) { String destinationName = (String) message.getHeaders().get("spring.cloud.stream.sendto.destination"); return streamBridge.resolveDestination(destinationName, producerProperties, null); } return outputName; - }).get(); + }) + .get(); IntegrationFlow postProcessedFlow = (IntegrationFlow) context.getAutowireCapableBeanFactory() .initializeBean(integrationFlow, integrationFlowName); context.registerBean(integrationFlowName, IntegrationFlow.class, () -> postProcessedFlow);