|
16 | 16 |
|
17 | 17 | package org.springframework.graphql.data.pagination;
|
18 | 18 |
|
| 19 | +import java.util.Arrays; |
19 | 20 | import java.util.Collection;
|
| 21 | +import java.util.HashMap; |
| 22 | +import java.util.HashSet; |
20 | 23 | import java.util.List;
|
| 24 | +import java.util.Map; |
| 25 | +import java.util.Set; |
21 | 26 | import java.util.function.Consumer;
|
22 | 27 |
|
23 | 28 | import graphql.execution.DataFetcherResult;
|
24 | 29 | import graphql.schema.DataFetcher;
|
25 | 30 | import graphql.schema.FieldCoordinates;
|
| 31 | +import graphql.schema.GraphQLCodeRegistry; |
26 | 32 | import graphql.schema.GraphQLFieldDefinition;
|
27 | 33 | import graphql.schema.GraphQLSchema;
|
| 34 | +import graphql.schema.GraphQLSchemaElement; |
| 35 | +import graphql.schema.GraphQLTypeVisitor; |
| 36 | +import graphql.schema.GraphQLTypeVisitorStub; |
28 | 37 | import graphql.schema.PropertyDataFetcher;
|
29 | 38 | import graphql.schema.SchemaTransformer;
|
| 39 | +import graphql.schema.SchemaTraverser; |
30 | 40 | import graphql.schema.idl.RuntimeWiring;
|
31 | 41 | import graphql.schema.idl.SchemaGenerator;
|
32 | 42 | import graphql.schema.idl.SchemaParser;
|
33 | 43 | import graphql.schema.idl.TypeDefinitionRegistry;
|
| 44 | +import graphql.util.TraversalControl; |
| 45 | +import graphql.util.TraverserContext; |
34 | 46 | import org.junit.jupiter.api.Nested;
|
35 | 47 | import org.junit.jupiter.api.Test;
|
36 | 48 | import reactor.core.publisher.Mono;
|
|
42 | 54 | import org.springframework.graphql.GraphQlSetup;
|
43 | 55 | import org.springframework.graphql.ResponseHelper;
|
44 | 56 | import org.springframework.graphql.execution.ConnectionTypeDefinitionConfigurer;
|
| 57 | +import org.springframework.graphql.execution.TypeVisitorHelper; |
45 | 58 |
|
46 | 59 | import static org.assertj.core.api.Assertions.assertThat;
|
47 | 60 |
|
@@ -223,6 +236,61 @@ else if (schemaSource instanceof String schemaContent) {
|
223 | 236 |
|
224 | 237 | }
|
225 | 238 |
|
| 239 | + @Nested |
| 240 | + class TypeVisitorTests { |
| 241 | + |
| 242 | + |
| 243 | + @Test // gh-861 |
| 244 | + void shouldNotAbortTraverseForSubscriptions() { |
| 245 | + String schemaContent = """ |
| 246 | + type Query { |
| 247 | + greeting: String |
| 248 | + } |
| 249 | + type Subscription { |
| 250 | + puzzles: Puzzle |
| 251 | + } |
| 252 | + type Puzzle { |
| 253 | + title: String! |
| 254 | + author: Author |
| 255 | + } |
| 256 | + type Author { |
| 257 | + name: String! |
| 258 | + } |
| 259 | + """; |
| 260 | + |
| 261 | + ConnectionFieldTypeVisitor connectionFieldTypeVisitor = ConnectionFieldTypeVisitor.create(List.of(new ListConnectionAdapter())); |
| 262 | + TrackingTypeVisitor trackingTypeVisitor = new TrackingTypeVisitor(); |
| 263 | + visitSchema(schemaContent, connectionFieldTypeVisitor, trackingTypeVisitor); |
| 264 | + assertThat(trackingTypeVisitor.visitedDefinitions).contains("puzzles", "title", "author", "name"); |
| 265 | + } |
| 266 | + |
| 267 | + private static void visitSchema(String schemaContent, GraphQLTypeVisitor... typeVisitors) { |
| 268 | + TypeDefinitionRegistry registry = new SchemaParser().parse(schemaContent); |
| 269 | + new ConnectionTypeDefinitionConfigurer().configure(registry); |
| 270 | + GraphQLSchema schema = new SchemaGenerator().makeExecutableSchema(registry, RuntimeWiring.newRuntimeWiring().build()); |
| 271 | + |
| 272 | + GraphQLCodeRegistry.Builder outputCodeRegistry = |
| 273 | + GraphQLCodeRegistry.newCodeRegistry(schema.getCodeRegistry()); |
| 274 | + Map<Class<?>, Object> vars = new HashMap<>(2); |
| 275 | + vars.put(GraphQLCodeRegistry.Builder.class, outputCodeRegistry); |
| 276 | + vars.put(TypeVisitorHelper.class, TypeVisitorHelper.create(schema)); |
| 277 | + |
| 278 | + List<GraphQLTypeVisitor> visitorsToUse = Arrays.asList(typeVisitors); |
| 279 | + new SchemaTraverser().depthFirstFullSchema(visitorsToUse, schema, vars); |
| 280 | + } |
| 281 | + |
| 282 | + static class TrackingTypeVisitor extends GraphQLTypeVisitorStub { |
| 283 | + |
| 284 | + Set<String> visitedDefinitions = new HashSet<>(); |
| 285 | + |
| 286 | + @Override |
| 287 | + public TraversalControl visitGraphQLFieldDefinition(GraphQLFieldDefinition node, TraverserContext<GraphQLSchemaElement> context) { |
| 288 | + this.visitedDefinitions.add(node.getName()); |
| 289 | + return TraversalControl.CONTINUE; |
| 290 | + } |
| 291 | + } |
| 292 | + } |
| 293 | + |
226 | 294 |
|
227 | 295 |
|
228 | 296 | private static class ListConnectionAdapter implements ConnectionAdapter {
|
|
0 commit comments