Skip to content

Commit 9fee079

Browse files
committed
Fix ConnectionFieldTypeVisitor traversal control
Prior to this commit, the `ConnectionFieldTypeVisitor` would return `TraversalControl.ABORT` when a field with a parent Subscription type is encountered. This is done because this visitor should not consider fields under Subscription types as they are not candidates for pagination. This would also completely abort the visiting of all fields under Subscription types, for all other visitors. As a result, this prevents the decoration of data fetchers by the `ContextTypeVisitor` and leads to missing context information (security or observability). This only applies to 21.x GraphQL Java versions, as a bug was hiding this behavior in previous versions. This commit ensures that the `ConnectionFieldTypeVisitor` ignores fields located under Subscription operations but always return `TraversalControl.CONTINUE` to not completely ignore this part of the schema. Fixes gh-861
1 parent d7606f9 commit 9fee079

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

spring-graphql/src/main/java/org/springframework/graphql/data/pagination/ConnectionFieldTypeVisitor.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ public TraversalControl visitGraphQLFieldDefinition(
8787
GraphQLFieldsContainer parent = (GraphQLFieldsContainer) context.getParentNode();
8888
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(parent, fieldDefinition);
8989

90-
if (visitorHelper != null && visitorHelper.isSubscriptionType(parent)) {
91-
return TraversalControl.ABORT;
90+
if (visitorHelper != null && isUnderSubscriptionOperation(visitorHelper, context)) {
91+
return TraversalControl.CONTINUE;
9292
}
9393

9494
if (isConnectionField(fieldDefinition)) {
@@ -108,6 +108,13 @@ public TraversalControl visitGraphQLFieldDefinition(
108108
return TraversalControl.CONTINUE;
109109
}
110110

111+
private static boolean isUnderSubscriptionOperation(TypeVisitorHelper visitorHelper, TraverserContext<GraphQLSchemaElement> context) {
112+
return context.getBreadcrumbs().stream()
113+
.filter(GraphQLFieldsContainer.class::isInstance)
114+
.map(GraphQLFieldsContainer.class::cast)
115+
.anyMatch(visitorHelper::isSubscriptionType);
116+
}
117+
111118
private static boolean isConnectionField(GraphQLFieldDefinition field) {
112119
GraphQLObjectType type = getAsObjectType(field);
113120
if (type == null || !type.getName().endsWith("Connection")) {

spring-graphql/src/test/java/org/springframework/graphql/data/pagination/ConnectionFieldTypeVisitorTests.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,33 @@
1616

1717
package org.springframework.graphql.data.pagination;
1818

19+
import java.util.Arrays;
1920
import java.util.Collection;
21+
import java.util.HashMap;
22+
import java.util.HashSet;
2023
import java.util.List;
24+
import java.util.Map;
25+
import java.util.Set;
2126
import java.util.function.Consumer;
2227

2328
import graphql.execution.DataFetcherResult;
2429
import graphql.schema.DataFetcher;
2530
import graphql.schema.FieldCoordinates;
31+
import graphql.schema.GraphQLCodeRegistry;
2632
import graphql.schema.GraphQLFieldDefinition;
2733
import graphql.schema.GraphQLSchema;
34+
import graphql.schema.GraphQLSchemaElement;
35+
import graphql.schema.GraphQLTypeVisitor;
36+
import graphql.schema.GraphQLTypeVisitorStub;
2837
import graphql.schema.PropertyDataFetcher;
2938
import graphql.schema.SchemaTransformer;
39+
import graphql.schema.SchemaTraverser;
3040
import graphql.schema.idl.RuntimeWiring;
3141
import graphql.schema.idl.SchemaGenerator;
3242
import graphql.schema.idl.SchemaParser;
3343
import graphql.schema.idl.TypeDefinitionRegistry;
44+
import graphql.util.TraversalControl;
45+
import graphql.util.TraverserContext;
3446
import org.junit.jupiter.api.Nested;
3547
import org.junit.jupiter.api.Test;
3648
import reactor.core.publisher.Mono;
@@ -42,6 +54,7 @@
4254
import org.springframework.graphql.GraphQlSetup;
4355
import org.springframework.graphql.ResponseHelper;
4456
import org.springframework.graphql.execution.ConnectionTypeDefinitionConfigurer;
57+
import org.springframework.graphql.execution.TypeVisitorHelper;
4558

4659
import static org.assertj.core.api.Assertions.assertThat;
4760

@@ -223,6 +236,61 @@ else if (schemaSource instanceof String schemaContent) {
223236

224237
}
225238

239+
@Nested
240+
class TypeVisitorTests {
241+
242+
243+
@Test // gh-861
244+
void shouldNotAbortTraverseForSubscriptions() {
245+
String schemaContent = """
246+
type Query {
247+
greeting: String
248+
}
249+
type Subscription {
250+
puzzles: Puzzle
251+
}
252+
type Puzzle {
253+
title: String!
254+
author: Author
255+
}
256+
type Author {
257+
name: String!
258+
}
259+
""";
260+
261+
ConnectionFieldTypeVisitor connectionFieldTypeVisitor = ConnectionFieldTypeVisitor.create(List.of(new ListConnectionAdapter()));
262+
TrackingTypeVisitor trackingTypeVisitor = new TrackingTypeVisitor();
263+
visitSchema(schemaContent, connectionFieldTypeVisitor, trackingTypeVisitor);
264+
assertThat(trackingTypeVisitor.visitedDefinitions).contains("puzzles", "title", "author", "name");
265+
}
266+
267+
private static void visitSchema(String schemaContent, GraphQLTypeVisitor... typeVisitors) {
268+
TypeDefinitionRegistry registry = new SchemaParser().parse(schemaContent);
269+
new ConnectionTypeDefinitionConfigurer().configure(registry);
270+
GraphQLSchema schema = new SchemaGenerator().makeExecutableSchema(registry, RuntimeWiring.newRuntimeWiring().build());
271+
272+
GraphQLCodeRegistry.Builder outputCodeRegistry =
273+
GraphQLCodeRegistry.newCodeRegistry(schema.getCodeRegistry());
274+
Map<Class<?>, Object> vars = new HashMap<>(2);
275+
vars.put(GraphQLCodeRegistry.Builder.class, outputCodeRegistry);
276+
vars.put(TypeVisitorHelper.class, TypeVisitorHelper.create(schema));
277+
278+
List<GraphQLTypeVisitor> visitorsToUse = Arrays.asList(typeVisitors);
279+
new SchemaTraverser().depthFirstFullSchema(visitorsToUse, schema, vars);
280+
}
281+
282+
static class TrackingTypeVisitor extends GraphQLTypeVisitorStub {
283+
284+
Set<String> visitedDefinitions = new HashSet<>();
285+
286+
@Override
287+
public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition node, TraverserContext<GraphQLSchemaElement> context) {
288+
this.visitedDefinitions.add(node.getName());
289+
return TraversalControl.CONTINUE;
290+
}
291+
}
292+
}
293+
226294

227295

228296
private static class ListConnectionAdapter implements ConnectionAdapter {

0 commit comments

Comments
 (0)