Skip to content

Commit

Permalink
GH-3009 Add post processing support for Supplier
Browse files Browse the repository at this point in the history
Resolves #3009
  • Loading branch information
olegz committed Sep 26, 2024
1 parent 090ccfa commit 0336ffd
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.cloud.stream.function;

import java.util.function.Function;
import java.util.function.Supplier;

import org.junit.jupiter.api.Test;

Expand All @@ -29,8 +30,10 @@
import org.springframework.cloud.stream.binder.test.TestChannelBinderConfiguration;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.GenericMessage;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -76,6 +79,22 @@ void successfulPostProcessingOfSingleFunction() {
}
}

@Test
void successfulPostProcessingOfSupplierFunctionCompposition() throws Exception {
System.clearProperty("spring.cloud.function.definition");
try (ConfigurableApplicationContext context = new SpringApplicationBuilder(
TestChannelBinderConfiguration.getCompleteConfiguration(SupplierPostProcessingTestConfiguration.class))
.web(WebApplicationType.NONE).run("--spring.jmx.enabled=false",
"--spring.cloud.function.definition=hello|uppercase",
"--spring.cloud.stream.bindings.hellouppercase-out-0.producer.poller.fixed-delay=100")) {
Thread.sleep(1000);
OutputDestination outputDestination = context.getBean(OutputDestination.class);

assertThat(outputDestination.receive(5000, "hellouppercase-out-0").getPayload()).isEqualTo("HELLO".getBytes());
assertThat(context.getBean(SupplierPostProcessingTestConfiguration.class).postProcessed).isTrue();
}
}

@Test
void noPostProcessingOnError() {
System.clearProperty("spring.cloud.function.definition");
Expand Down Expand Up @@ -207,6 +226,31 @@ public Function<String, String> reverse() {
}
}

@EnableAutoConfiguration
@Configuration
public static class SupplierPostProcessingTestConfiguration {

public static boolean postProcessed;

@Bean
public Supplier<Message<String>> hello() {
return () -> new GenericMessage<>("hello");
}

@Bean
public Function<String, String> uppercase() {
return new PostProcessingFunction<String, String>() {
public String apply(String input) {
return input.toUpperCase();
}

public void postProcess(Message<String> result) {
postProcessed = true;
}
};
}
}

private static class SingleFunctionPostProcessingFunction implements PostProcessingFunction<String, String> {

private boolean success;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.SubscribableChannel;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.scheduling.Trigger;
import org.springframework.scheduling.support.CronTrigger;
Expand Down Expand Up @@ -235,15 +236,22 @@ InitializingBean supplierInitializer(FunctionCatalog functionCatalog, StreamFunc
}

if (functionWrapper != null) {
FunctionInvocationWrapper postProcessor = functionWrapper;
IntegrationFlow integrationFlow = integrationFlowFromProvidedSupplier(new PartitionAwareFunctionWrapper(functionWrapper, context, producerProperties),
pollable, context, taskScheduler, producerProperties, outputName)
.intercept(new ChannelInterceptor() {
public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
postProcessor.postProcess();
}
})
.route(Message.class, message -> {
if (message.getHeaders().get("spring.cloud.stream.sendto.destination") != null) {
String destinationName = (String) message.getHeaders().get("spring.cloud.stream.sendto.destination");
return streamBridge.resolveDestination(destinationName, producerProperties, null);
}
return outputName;
}).get();
})
.get();
IntegrationFlow postProcessedFlow = (IntegrationFlow) context.getAutowireCapableBeanFactory()
.initializeBean(integrationFlow, integrationFlowName);
context.registerBean(integrationFlowName, IntegrationFlow.class, () -> postProcessedFlow);
Expand Down

0 comments on commit 0336ffd

Please sign in to comment.