propagate scope in async failures #3950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
@@ -1498,7 +1498,7 @@ protected void handleAsyncFailure() {
 		// We will give up on retrying with the remaining copied and failed Records.
 		for (FailedRecordTuple<K, V> copyFailedRecord : copyFailedRecords) {
 			try {
-				invokeErrorHandlerBySingleRecord(copyFailedRecord);
+				copyFailedRecord.observation.scoped(() -> invokeErrorHandlerBySingleRecord(copyFailedRecord));
 			}
 			catch (Exception e) {
 				this.logger.warn(() ->
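
The one-line change above is the core of the fix: the error handler is re-invoked inside the scope of the observation that was current when the record originally failed. A minimal sketch of the `Observation#scoped` contract, using only plain Micrometer types (no spring-kafka API):

```java
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;

// Minimal sketch of Observation#scoped: the runnable executes with this
// observation installed as the thread-local "current" observation, so
// tracing handlers see the original context even on a different call path.
public class ScopedReplaySketch {

	public static void main(String[] args) {
		ObservationRegistry registry = ObservationRegistry.create();
		Observation observation = Observation.start("record.process", registry);
		try {
			observation.scoped(() ->
					// Inside the scope, the registry reports this observation as current.
					System.out.println(registry.getCurrentObservation() != null));
		}
		finally {
			observation.stop();
		}
	}
}
```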
@@ -3432,8 +3432,13 @@ private Collection<ConsumerRecord<K, V>> getHighestOffsetRecords(ConsumerRecords
 			.values();
 	}

+	private Observation getCurrentObservation() {
+		Observation currentObservation = this.observationRegistry.getCurrentObservation();
+		return currentObservation == null ? Observation.NOOP : currentObservation;
+	}
+
 	private void callbackForAsyncFailure(ConsumerRecord<K, V> cRecord, RuntimeException ex) {
-		this.failedRecords.addLast(new FailedRecordTuple<>(cRecord, ex));
+		this.failedRecords.addLast(new FailedRecordTuple<>(cRecord, ex, getCurrentObservation()));
 	}

 	@Override
@@ -4050,6 +4055,6 @@ private static class StopAfterFenceException extends KafkaException {

 	}

-	private record FailedRecordTuple<K, V>(ConsumerRecord<K, V> record, RuntimeException ex) { }
+	private record FailedRecordTuple<K, V>(ConsumerRecord<K, V> record, RuntimeException ex, Observation observation) { }

 }
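
Taken together, the three hunks above implement a capture-and-replay pattern: remember the observation that was current when the async failure was reported, then replay the error handling under that observation's scope. A self-contained sketch of the same pattern; `FailedWork`, `onAsyncFailure`, and `handleFailure` are illustrative names, not spring-kafka API:

```java
import java.util.ArrayDeque;
import java.util.Deque;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;

// Capture-and-replay sketch: store the current observation alongside the
// failed work, then run the error handler inside that observation's scope.
public class CaptureAndReplaySketch {

	record FailedWork(String payload, RuntimeException error, Observation observation) { }

	private static final ObservationRegistry REGISTRY = ObservationRegistry.create();

	private static final Deque<FailedWork> FAILED = new ArrayDeque<>();

	static void onAsyncFailure(String payload, RuntimeException ex) {
		// Mirror of getCurrentObservation(): fall back to NOOP so scoped() is always safe to call.
		Observation current = REGISTRY.getCurrentObservation();
		FAILED.addLast(new FailedWork(payload, ex, current == null ? Observation.NOOP : current));
	}

	static void handleFailure(FailedWork work) {
		System.out.println("handling " + work.payload() + ": " + work.error().getMessage());
	}

	public static void main(String[] args) {
		Observation obs = Observation.start("listener", REGISTRY);
		obs.scoped(() -> onAsyncFailure("value-1", new RuntimeException("boom")));
		obs.stop();

		// Later, possibly on another thread: replay each failure in its original scope.
		for (FailedWork work : FAILED) {
			work.observation().scoped(() -> handleFailure(work));
		}
	}
}
```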

@@ -549,7 +549,7 @@ else if (!(result instanceof CompletableFuture<?>)) {
 		}

 		completableFutureResult.whenComplete((r, t) -> {
-			try {
+			try (var scope = observation.openScope()) {
 				if (t == null) {
 					asyncSuccess(r, replyTopic, source, messageReturnType);
 					if (isAsyncReplies()) {
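
The same idea on the adapter side: the observation captured before the async call is re-opened with try-with-resources around the completion callback, which otherwise runs on an arbitrary thread with no scope. A minimal sketch, assuming only the Micrometer observation API:

```java
import java.util.concurrent.CompletableFuture;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;

// The callback runs on whatever thread completes the future; re-opening the
// scope makes the captured observation current for both branches.
public class OpenScopeSketch {

	public static void main(String[] args) {
		ObservationRegistry registry = ObservationRegistry.create();
		Observation observation = Observation.start("handler.invoke", registry);

		CompletableFuture.supplyAsync(() -> "reply").whenComplete((r, t) -> {
			try (Observation.Scope scope = observation.openScope()) {
				if (t == null) {
					System.out.println("success in scope: " + r);
				}
				else {
					System.out.println("failure in scope: " + t);
				}
			}
			finally {
				observation.stop();
			}
		}).join();
	}
}
```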
@@ -736,13 +736,15 @@ protected void asyncFailure(Object request, @Nullable Acknowledgment acknowledgment,
"Async Fail", Objects.requireNonNull(source).getPayload()), cause));
}
catch (Throwable ex) {
this.logger.error(t, () -> "Future, Mono, or suspend function was completed with an exception for " + source);
acknowledge(acknowledgment);
if (canAsyncRetry(request, ex) && this.asyncRetryCallback != null) {
@SuppressWarnings("unchecked")
ConsumerRecord<K, V> record = (ConsumerRecord<K, V>) request;
this.asyncRetryCallback.accept(record, (RuntimeException) ex);
}
else {
this.logger.error(ex, () -> "Future, Mono, or suspend function was completed with an exception for " + source);
}
}
}

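This hunk also reorders the failure branch so that a record headed for an async retry is no longer logged as an error on every attempt; the log now fires only when no retry will happen. A hedged sketch of that control flow, where `onAsyncFailure`, `retryCallback`, and `canRetry` are hypothetical stand-ins rather than the adapter's API:

```java
import java.util.function.BiConsumer;

// Decision flow: acknowledge, then either hand off for retry or log once.
public class AsyncFailureFlowSketch {

	static BiConsumer<String, RuntimeException> retryCallback = (rec, ex) ->
			System.out.println("re-queued " + rec + " after: " + ex.getMessage());

	static void onAsyncFailure(String record, RuntimeException ex, boolean canRetry) {
		// acknowledge(...) would happen here in the real adapter
		if (canRetry && retryCallback != null) {
			retryCallback.accept(record, ex); // retried; no error log yet
		}
		else {
			System.err.println("terminal async failure for " + record + ": " + ex.getMessage());
		}
	}

	public static void main(String[] args) {
		onAsyncFailure("record-1", new RuntimeException("boom"), true);
		onAsyncFailure("record-2", new RuntimeException("boom"), false);
	}
}
```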

@@ -38,15 +38,16 @@
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationHandler;
 import io.micrometer.observation.ObservationRegistry;
-import io.micrometer.observation.tck.TestObservationRegistry;
 import io.micrometer.tracing.Span;
 import io.micrometer.tracing.TraceContext;
 import io.micrometer.tracing.Tracer;
 import io.micrometer.tracing.handler.DefaultTracingObservationHandler;
 import io.micrometer.tracing.handler.PropagatingReceiverTracingObservationHandler;
 import io.micrometer.tracing.handler.PropagatingSenderTracingObservationHandler;
+import io.micrometer.tracing.handler.TracingAwareMeterObservationHandler;
 import io.micrometer.tracing.propagation.Propagator;
 import io.micrometer.tracing.test.simple.SimpleSpan;
+import io.micrometer.tracing.test.simple.SimpleTraceContext;
 import io.micrometer.tracing.test.simple.SimpleTracer;
 import org.apache.kafka.clients.admin.AdminClientConfig;
 import org.apache.kafka.clients.consumer.Consumer;
@@ -70,8 +71,10 @@
 import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Primary;
 import org.springframework.kafka.KafkaException;
+import org.springframework.kafka.annotation.DltHandler;
 import org.springframework.kafka.annotation.EnableKafka;
 import org.springframework.kafka.annotation.KafkaListener;
+import org.springframework.kafka.annotation.RetryableTopic;
 import org.springframework.kafka.config.ConcurrentKafkaListenerContainerFactory;
 import org.springframework.kafka.config.KafkaListenerEndpointRegistry;
 import org.springframework.kafka.core.ConsumerFactory;
@@ -80,6 +83,7 @@
 import org.springframework.kafka.core.KafkaAdmin;
 import org.springframework.kafka.core.KafkaTemplate;
 import org.springframework.kafka.core.ProducerFactory;
+import org.springframework.kafka.listener.ContainerProperties;
 import org.springframework.kafka.listener.MessageListenerContainer;
 import org.springframework.kafka.listener.RecordInterceptor;
 import org.springframework.kafka.requestreply.ReplyingKafkaTemplate;
@@ -90,6 +94,9 @@
 import org.springframework.kafka.test.context.EmbeddedKafka;
 import org.springframework.kafka.test.utils.KafkaTestUtils;
 import org.springframework.messaging.handler.annotation.SendTo;
+import org.springframework.retry.annotation.Backoff;
+import org.springframework.scheduling.TaskScheduler;
+import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
 import org.springframework.test.annotation.DirtiesContext;
 import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
 import org.springframework.util.StringUtils;
@@ -113,7 +120,8 @@
 @EmbeddedKafka(topics = {ObservationTests.OBSERVATION_TEST_1, ObservationTests.OBSERVATION_TEST_2,
 		ObservationTests.OBSERVATION_TEST_3, ObservationTests.OBSERVATION_TEST_4, ObservationTests.OBSERVATION_REPLY,
 		ObservationTests.OBSERVATION_RUNTIME_EXCEPTION, ObservationTests.OBSERVATION_ERROR,
-		ObservationTests.OBSERVATION_TRACEPARENT_DUPLICATE}, partitions = 1)
+		ObservationTests.OBSERVATION_TRACEPARENT_DUPLICATE, ObservationTests.OBSERVATION_ASYNC_FAILURE_TEST,
+		ObservationTests.OBSERVATION_ASYNC_FAILURE_WITH_RETRY_TEST}, partitions = 1)
 @DirtiesContext
 public class ObservationTests {

@@ -137,6 +145,55 @@ public class ObservationTests {

 	public final static String OBSERVATION_TRACEPARENT_DUPLICATE = "observation.traceparent.duplicate";

+	public final static String OBSERVATION_ASYNC_FAILURE_TEST = "observation.async.failure.test";
+
+	public final static String OBSERVATION_ASYNC_FAILURE_WITH_RETRY_TEST = "observation.async.failure.retry.test";
+
+	@Test
+	void asyncRetryScopePropagation(@Autowired AsyncFailureListener asyncFailureListener,
+			@Autowired KafkaTemplate<Integer, String> template,
+			@Autowired SimpleTracer tracer,
+			@Autowired ObservationRegistry observationRegistry) throws InterruptedException {
+
+		// Clear any previous spans
+		tracer.getSpans().clear();
+
+		// Create an observation scope to ensure we have a proper trace context
+		var testObservation = Observation.createNotStarted("test.message.send", observationRegistry);
+
+		// Send a message within the observation scope to ensure trace context is propagated
+		testObservation.observe(() -> {
+			try {
+				template.send(OBSERVATION_ASYNC_FAILURE_TEST, "trigger-async-failure").get(5, TimeUnit.SECONDS);
+			}
+			catch (Exception e) {
+				throw new RuntimeException("Failed to send message", e);
+			}
+		});
+
+		// Wait for the listener to process the message (initial + retry + DLT = 3 invocations)
+		assertThat(asyncFailureListener.asyncFailureLatch.await(30, TimeUnit.SECONDS)).isTrue();
+
+		// Verify that the captured spans from the listener contexts are all part of the same trace.
+		// This demonstrates that the tracing context propagates correctly through the retry mechanism.
+		Deque<SimpleSpan> spans = tracer.getSpans();
+		assertThat(spans).hasSizeGreaterThanOrEqualTo(4); // template + listener + retry + DLT spans
+
+		// Verify that spans were captured for each phase and belong to the same trace
+		assertThat(asyncFailureListener.capturedSpanInListener).isNotNull();
+		assertThat(asyncFailureListener.capturedSpanInRetry).isNotNull();
+		assertThat(asyncFailureListener.capturedSpanInDlt).isNotNull();
+
+		// All spans should have the same trace ID, demonstrating trace continuity
+		var originalTraceId = asyncFailureListener.capturedSpanInListener.getTraceId();
+		assertThat(originalTraceId).isNotBlank();
+		assertThat(asyncFailureListener.capturedSpanInRetry.getTraceId()).isEqualTo(originalTraceId);
+		assertThat(asyncFailureListener.capturedSpanInDlt.getTraceId()).isEqualTo(originalTraceId);
+
+		// Clear the spans so subsequent tests start with an empty queue
+		tracer.getSpans().clear();
+	}
+
 	@Test
 	void endToEnd(@Autowired Listener listener, @Autowired KafkaTemplate<Integer, String> template,
 			@Autowired SimpleTracer tracer, @Autowired KafkaListenerEndpointRegistry rler,
@@ -628,6 +685,11 @@ ConcurrentKafkaListenerContainerFactory<Integer, String> kafkaListenerContainerF
if (container.getListenerId().equals("obs3")) {
container.setKafkaAdmin(this.mockAdmin);
}
if (container.getListenerId().contains("asyncFailure")) {
// Enable async acks to trigger async failure handling
container.getContainerProperties().setAsyncAcks(true);
container.getContainerProperties().setAckMode(ContainerProperties.AckMode.MANUAL);
}
if (container.getListenerId().equals("obs4")) {
container.setRecordInterceptor(new RecordInterceptor<>() {

@@ -662,17 +724,17 @@ MeterRegistry meterRegistry() {

 	@Bean
 	ObservationRegistry observationRegistry(Tracer tracer, Propagator propagator, MeterRegistry meterRegistry) {
-		TestObservationRegistry observationRegistry = TestObservationRegistry.create();
+		var observationRegistry = ObservationRegistry.create();
 		observationRegistry.observationConfig().observationHandler(
 				// Composite will pick the first matching handler
 				new ObservationHandler.FirstMatchingCompositeObservationHandler(
-						// This is responsible for creating a child span on the sender side
-						new PropagatingSenderTracingObservationHandler<>(tracer, propagator),
 						// This is responsible for creating a span on the receiver side
 						new PropagatingReceiverTracingObservationHandler<>(tracer, propagator),
+						// This is responsible for creating a child span on the sender side
+						new PropagatingSenderTracingObservationHandler<>(tracer, propagator),
 						// This is responsible for creating a default span
 						new DefaultTracingObservationHandler(tracer)))
-				.observationHandler(new DefaultMeterObservationHandler(meterRegistry));
+				.observationHandler(new TracingAwareMeterObservationHandler<>(new DefaultMeterObservationHandler(meterRegistry), tracer));
 		return observationRegistry;
 	}

@@ -683,29 +745,41 @@ Propagator propagator(Tracer tracer) {
 			// List of headers required for tracing propagation
 			@Override
 			public List<String> fields() {
-				return Arrays.asList("foo", "bar");
+				return Arrays.asList("traceId", "spanId", "foo", "bar");
 			}

 			// This is called on the producer side when the message is being sent
 			// Normally we would pass information from tracing context - for tests we don't need to
 			@Override
 			public <C> void inject(TraceContext context, @Nullable C carrier, Setter<C> setter) {
 				setter.set(carrier, "foo", "some foo value");
 				setter.set(carrier, "bar", "some bar value");
+
+				setter.set(carrier, "traceId", context.traceId());
+				setter.set(carrier, "spanId", context.spanId());

 				// Add a traceparent header to simulate W3C trace context
 				setter.set(carrier, "traceparent", "traceparent-from-propagator");
 			}

 			// This is called on the consumer side when the message is consumed
 			// Normally we would use tools like Extractor from tracing but for tests we are just manually creating a span
 			@Override
 			public <C> Span.Builder extract(C carrier, Getter<C> getter) {
 				String foo = getter.get(carrier, "foo");
 				String bar = getter.get(carrier, "bar");
-				return tracer.spanBuilder()
+
+				var traceId = getter.get(carrier, "traceId");
+				var spanId = getter.get(carrier, "spanId");
+
+				Span.Builder spanBuilder = tracer.spanBuilder()
 						.tag("foo", foo)
 						.tag("bar", bar);
+
+				var traceContext = new SimpleTraceContext();
+				traceContext.setTraceId(traceId);
+				traceContext.setSpanId(spanId);
+				spanBuilder = spanBuilder.setParent(traceContext);
+
+				return spanBuilder;
 			}
 		};
 	}
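
The enriched test propagator now round-trips real trace identifiers instead of only the fixed foo/bar headers, which is what lets the new test assert a single trace ID across listener, retry, and DLT. A sketch of that round trip using a plain `Map` as a stand-in carrier; the header values below are made up:

```java
import java.util.HashMap;
import java.util.Map;

import io.micrometer.tracing.Span;
import io.micrometer.tracing.test.simple.SimpleTraceContext;
import io.micrometer.tracing.test.simple.SimpleTracer;

// Round trip: "inject" writes identifiers into a carrier, "extract" rebuilds
// a parent context so the consumer-side span joins the producer's trace.
public class PropagatorRoundTripSketch {

	public static void main(String[] args) {
		SimpleTracer tracer = new SimpleTracer();
		Map<String, String> headers = new HashMap<>();

		// Producer side: identifiers copied into the carrier (made-up values)
		headers.put("traceId", "4bf92f3577b34da6a3ce929d0e0e4736");
		headers.put("spanId", "00f067aa0ba902b7");

		// Consumer side: rebuild the parent context from the carrier
		SimpleTraceContext parent = new SimpleTraceContext();
		parent.setTraceId(headers.get("traceId"));
		parent.setSpanId(headers.get("spanId"));

		// The new span is linked to the extracted parent context
		Span span = tracer.spanBuilder().setParent(parent).start();
		System.out.println(span);
		span.end();
	}
}
```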
@@ -720,6 +794,15 @@ ExceptionListener exceptionListener() {
 		return new ExceptionListener();
 	}

+	@Bean
+	AsyncFailureListener asyncFailureListener(SimpleTracer tracer) {
+		return new AsyncFailureListener(tracer);
+	}
+
+	@Bean
+	public TaskScheduler taskExecutor() {
+		return new ThreadPoolTaskScheduler();
+	}
 }

 	public static class Listener {

@@ -801,4 +884,54 @@ Mono<Void> receive1(ConsumerRecord<Object, Object> record) {

 	}

+	public static class AsyncFailureListener {
+
+		final CountDownLatch asyncFailureLatch = new CountDownLatch(3);
+
+		volatile @Nullable SimpleSpan capturedSpanInListener;
+
+		volatile @Nullable SimpleSpan capturedSpanInRetry;
+
+		volatile @Nullable SimpleSpan capturedSpanInDlt;
+
+		private final SimpleTracer tracer;
+
+		public AsyncFailureListener(SimpleTracer tracer) {
+			this.tracer = tracer;
+		}
+
+		@RetryableTopic(
+				attempts = "2",
+				backoff = @Backoff(delay = 1000)
+		)
+		@KafkaListener(id = "asyncFailure", topics = OBSERVATION_ASYNC_FAILURE_TEST)
+		CompletableFuture<Void> handleAsync(ConsumerRecord<Integer, String> record) {
+
+			// Use topic name to distinguish between original and retry calls
+			String topicName = record.topic();
+
+			if (topicName.equals(OBSERVATION_ASYNC_FAILURE_TEST)) {
+				// This is the original call
+				this.capturedSpanInListener = this.tracer.currentSpan();
+			}
+			else {
+				// This is a retry call (topic name will be different for retry topics)
+				this.capturedSpanInRetry = this.tracer.currentSpan();
+			}
+
+			this.asyncFailureLatch.countDown();
+
+			// Return a failed CompletableFuture to trigger async failure handling
+			return CompletableFuture.supplyAsync(() -> {
+				throw new RuntimeException("Async failure for observation test");
+			});
+		}
+
+		@DltHandler
+		void handleDlt(ConsumerRecord<Integer, String> record, Exception exception) {
+			this.capturedSpanInDlt = this.tracer.currentSpan();
+			this.asyncFailureLatch.countDown();
+		}
+	}

 }