From eae31ca73eb3473c544710955d1dbebc22605bfe Mon Sep 17 00:00:00 2001 From: Tony Ghita Date: Tue, 18 Jan 2022 13:14:45 -0800 Subject: [PATCH] validation: fix bug in maxDepth fragment spread logic (#492) --- graphql_test.go | 29 ++++++++ .../validation/validate_max_depth_test.go | 67 ++++++++++++++++++- internal/validation/validation.go | 24 +++++-- 3 files changed, 113 insertions(+), 7 deletions(-) diff --git a/graphql_test.go b/graphql_test.go index 497a74f3..c8d9593b 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -4297,3 +4297,32 @@ func TestInterfaceImplementingInterface(t *testing.T) { `, }}) } + +func TestCircularFragmentMaxDepth(t *testing.T) { + withMaxDepth := graphql.MustParseSchema(starwars.Schema, &starwars.Resolver{}, graphql.MaxDepth(2)) + gqltesting.RunTests(t, []*gqltesting.Test{ + { + Schema: withMaxDepth, + Query: ` + query { + ...X + } + + fragment X on Query { + ...Y + } + fragment Y on Query { + ...X + } + `, + ExpectedErrors: []*gqlerrors.QueryError{{ + Message: `Cannot spread fragment "X" within itself via Y.`, + Rule: "NoFragmentCycles", + Locations: []gqlerrors.Location{ + {Line: 7, Column: 20}, + {Line: 10, Column: 20}, + }, + }}, + }, + }) +} diff --git a/internal/validation/validate_max_depth_test.go b/internal/validation/validate_max_depth_test.go index f8bfd9a8..613b22a8 100644 --- a/internal/validation/validate_max_depth_test.go +++ b/internal/validation/validate_max_depth_test.go @@ -34,6 +34,7 @@ const ( id: ID! name: String! friends: [Character] + enemies: [Character] appearsIn: [Episode]! } @@ -43,12 +44,15 @@ const ( JEDI } - type Starship {} + type Starship { + id: ID! + } type Human implements Character { id: ID! name: String! friends: [Character] + enemies: [Character] appearsIn: [Episode]! starships: [Starship] totalCredits: Int @@ -58,6 +62,7 @@ const ( id: ID! name: String! friends: [Character] + enemies: [Character] appearsIn: [Episode]! primaryFunction: String }` @@ -304,6 +309,64 @@ func TestMaxDepthFragmentSpreads(t *testing.T) { depth: 6, failure: true, }, + { + name: "spreadAtDifferentDepths", + query: ` + fragment character on Character { + name # depth + 0 + friends { # depth + 0 + name # depth + 1 + } + } + + query laterDepthValidated { + ...character # depth 1 (+1) + enemies { # depth 1 + friends { # depth 2 + ...character # depth 2 (+1), should error! + } + } + } + `, + depth: 2, + failure: true, + }, + { + name: "spreadAtSameDepth", + query: ` + fragment character on Character { + name # depth + 0 + friends { # depth + 0 + name # depth + 1 + } + } + query { + characters { # depth 1 + friends { # depth 2 + ...character # depth 3 (+1) + } + enemies { # depth 2 + ...character # depth 3 (+1) + } + } + } + `, + depth: 4, + }, + { + name: "fragmentCycle", + query: ` + fragment X on Query { ...Y } + fragment Y on Query { ...Z } + fragment Z on Query { ...X } + + query { + ...X + } + `, + depth: 10, + failure: true, + }, } { tc.Run(t, s) } @@ -431,7 +494,7 @@ func TestMaxDepthValidation(t *testing.T) { opc := &opContext{context: context, ops: doc.Operations} - actual := validateMaxDepth(opc, op.Selections, 1) + actual := validateMaxDepth(opc, op.Selections, nil, 1) if actual != tc.expected { t.Errorf("expected %t, actual %t", tc.expected, actual) } diff --git a/internal/validation/validation.go b/internal/validation/validation.go index d456dbe1..e3672638 100644 --- a/internal/validation/validation.go +++ b/internal/validation/validation.go @@ -76,7 +76,7 @@ func Validate(s *types.Schema, doc *types.ExecutableDefinition, variables map[st // Check if max depth is exceeded, if it's set. If max depth is exceeded, // don't continue to validate the document and exit early. - if validateMaxDepth(opc, op.Selections, 1) { + if validateMaxDepth(opc, op.Selections, nil, 1) { return c.errs } @@ -235,13 +235,19 @@ func validateValue(c *opContext, v *types.InputValueDefinition, val interface{}, // validates the query doesn't go deeper than maxDepth (if set). Returns whether // or not query validated max depth to avoid excessive recursion. -func validateMaxDepth(c *opContext, sels []types.Selection, depth int) bool { +// +// The visited map is necessary to ensure that max depth validation does not get stuck in cyclical +// fragment spreads. +func validateMaxDepth(c *opContext, sels []types.Selection, visited map[*types.FragmentDefinition]struct{}, depth int) bool { // maxDepth checking is turned off when maxDepth is 0 if c.maxDepth == 0 { return false } exceededMaxDepth := false + if visited == nil { + visited = map[*types.FragmentDefinition]struct{}{} + } for _, sel := range sels { switch sel := sel.(type) { @@ -251,11 +257,12 @@ func validateMaxDepth(c *opContext, sels []types.Selection, depth int) bool { c.addErr(sel.Alias.Loc, "MaxDepthExceeded", "Field %q has depth %d that exceeds max depth %d", sel.Name.Name, depth, c.maxDepth) continue } - exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.SelectionSet, depth+1) + exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.SelectionSet, visited, depth+1) + case *types.InlineFragment: // Depth is not checked because inline fragments resolve to other fields which are checked. // Depth is not incremented because inline fragments have the same depth as neighboring fields - exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.Selections, depth) + exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, sel.Selections, visited, depth) case *types.FragmentSpread: // Depth is not checked because fragments resolve to other fields which are checked. frag := c.doc.Fragments.Get(sel.Name.Name) @@ -264,8 +271,15 @@ func validateMaxDepth(c *opContext, sels []types.Selection, depth int) bool { c.addErr(sel.Loc, "MaxDepthEvaluationError", "Unknown fragment %q. Unable to evaluate depth.", sel.Name.Name) continue } + + if _, ok := visited[frag]; ok { + // we've already seen this fragment, don't check depth again. + continue + } + visited[frag] = struct{}{} + // Depth is not incremented because fragments have the same depth as surrounding fields - exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, frag.Selections, depth) + exceededMaxDepth = exceededMaxDepth || validateMaxDepth(c, frag.Selections, visited, depth) } }