Skip to content

Commit

Permalink
Fix ConnectionFieldTypeVisitor traversal control
Browse files Browse the repository at this point in the history
Prior to this commit, the `ConnectionFieldTypeVisitor` would return
`TraversalControl.ABORT` when a field with a parent Subscription type is
encountered. This is done because this visitor should not consider
fields under Subscription types as they are not candidates for
pagination. This would also completely abort the visiting of all fields
under Subscription types, for all other visitors. As a result, this
prevents the decoration of data fetchers by the `ContextTypeVisitor` and
leads to missing context information (security or observability).

This only applies to 21.x GraphQL Java versions, as a bug was hiding
this behavior in previous versions.

This commit ensures that the `ConnectionFieldTypeVisitor` ignores fields
located under Subscription operations but always return
`TraversalControl.CONTINUE` to not completely ignore this part of the
schema.

Fixes gh-861
  • Loading branch information
bclozel committed Dec 4, 2023
1 parent d7606f9 commit 9fee079
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ public TraversalControl visitGraphQLFieldDefinition(
GraphQLFieldsContainer parent = (GraphQLFieldsContainer) context.getParentNode();
DataFetcher<?> dataFetcher = codeRegistry.getDataFetcher(parent, fieldDefinition);

if (visitorHelper != null && visitorHelper.isSubscriptionType(parent)) {
return TraversalControl.ABORT;
if (visitorHelper != null && isUnderSubscriptionOperation(visitorHelper, context)) {
return TraversalControl.CONTINUE;
}

if (isConnectionField(fieldDefinition)) {
Expand All @@ -108,6 +108,13 @@ public TraversalControl visitGraphQLFieldDefinition(
return TraversalControl.CONTINUE;
}

private static boolean isUnderSubscriptionOperation(TypeVisitorHelper visitorHelper, TraverserContext<GraphQLSchemaElement> context) {
return context.getBreadcrumbs().stream()
.filter(GraphQLFieldsContainer.class::isInstance)
.map(GraphQLFieldsContainer.class::cast)
.anyMatch(visitorHelper::isSubscriptionType);
}

private static boolean isConnectionField(GraphQLFieldDefinition field) {
GraphQLObjectType type = getAsObjectType(field);
if (type == null || !type.getName().endsWith("Connection")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,33 @@

package org.springframework.graphql.data.pagination;

import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import graphql.execution.DataFetcherResult;
import graphql.schema.DataFetcher;
import graphql.schema.FieldCoordinates;
import graphql.schema.GraphQLCodeRegistry;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLSchema;
import graphql.schema.GraphQLSchemaElement;
import graphql.schema.GraphQLTypeVisitor;
import graphql.schema.GraphQLTypeVisitorStub;
import graphql.schema.PropertyDataFetcher;
import graphql.schema.SchemaTransformer;
import graphql.schema.SchemaTraverser;
import graphql.schema.idl.RuntimeWiring;
import graphql.schema.idl.SchemaGenerator;
import graphql.schema.idl.SchemaParser;
import graphql.schema.idl.TypeDefinitionRegistry;
import graphql.util.TraversalControl;
import graphql.util.TraverserContext;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
Expand All @@ -42,6 +54,7 @@
import org.springframework.graphql.GraphQlSetup;
import org.springframework.graphql.ResponseHelper;
import org.springframework.graphql.execution.ConnectionTypeDefinitionConfigurer;
import org.springframework.graphql.execution.TypeVisitorHelper;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -223,6 +236,61 @@ else if (schemaSource instanceof String schemaContent) {

}

@Nested
class TypeVisitorTests {


@Test // gh-861
void shouldNotAbortTraverseForSubscriptions() {
String schemaContent = """
type Query {
greeting: String
}
type Subscription {
puzzles: Puzzle
}
type Puzzle {
title: String!
author: Author
}
type Author {
name: String!
}
""";

ConnectionFieldTypeVisitor connectionFieldTypeVisitor = ConnectionFieldTypeVisitor.create(List.of(new ListConnectionAdapter()));
TrackingTypeVisitor trackingTypeVisitor = new TrackingTypeVisitor();
visitSchema(schemaContent, connectionFieldTypeVisitor, trackingTypeVisitor);
assertThat(trackingTypeVisitor.visitedDefinitions).contains("puzzles", "title", "author", "name");
}

private static void visitSchema(String schemaContent, GraphQLTypeVisitor... typeVisitors) {
TypeDefinitionRegistry registry = new SchemaParser().parse(schemaContent);
new ConnectionTypeDefinitionConfigurer().configure(registry);
GraphQLSchema schema = new SchemaGenerator().makeExecutableSchema(registry, RuntimeWiring.newRuntimeWiring().build());

GraphQLCodeRegistry.Builder outputCodeRegistry =
GraphQLCodeRegistry.newCodeRegistry(schema.getCodeRegistry());
Map<Class<?>, Object> vars = new HashMap<>(2);
vars.put(GraphQLCodeRegistry.Builder.class, outputCodeRegistry);
vars.put(TypeVisitorHelper.class, TypeVisitorHelper.create(schema));

List<GraphQLTypeVisitor> visitorsToUse = Arrays.asList(typeVisitors);
new SchemaTraverser().depthFirstFullSchema(visitorsToUse, schema, vars);
}

static class TrackingTypeVisitor extends GraphQLTypeVisitorStub {

Set<String> visitedDefinitions = new HashSet<>();

@Override
public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition node, TraverserContext<GraphQLSchemaElement> context) {
this.visitedDefinitions.add(node.getName());
return TraversalControl.CONTINUE;
}
}
}



private static class ListConnectionAdapter implements ConnectionAdapter {
Expand Down

0 comments on commit 9fee079

Please sign in to comment.