diff --git a/requested_fields.go b/requested_fields.go index 7536377..51099de 100644 --- a/requested_fields.go +++ b/requested_fields.go @@ -31,3 +31,27 @@ func RequestedForAt(ctx context.Context, resolver interface{}, pathToAppend stri return tree[pathTree] } + +// RequestedForContaining returns all requested fields for +//some path from a reference Resolver. +func RequestedForContaining(ctx context.Context, resolverString string) []string { + tree := ctx.Value(ContextKey).(map[string][]string) + + var src = make(map[string]interface{}) + for key, value := range tree { + if strings.Contains(key, resolverString) { + for _, val := range value { + src[val] = nil + } + } + } + + var i = 0 + var srcList = make([]string, len(src)) + for k := range src { + srcList[i] = k + i = i + 1 + } + + return srcList +} diff --git a/requested_fields_test.go b/requested_fields_test.go index bda3d28..bc71c5d 100644 --- a/requested_fields_test.go +++ b/requested_fields_test.go @@ -2,8 +2,10 @@ package fields import ( "context" - "github.com/stretchr/testify/assert" + "sort" "testing" + + "github.com/stretchr/testify/assert" ) var graphql_query_products string = ` @@ -86,3 +88,35 @@ func TestRequestedFieldsForUser(t *testing.T) { assert.Equal(t, expected_fields, requested_fields) } + +var graphql_query_user_nested string = ` +{ + user(id: 3) { + id + name + user { + id + age + height + } + } +} +` + +func TestRequestedFieldsForContainingUser(t *testing.T) { + query_resolver := &QueryResolver{} + + user_resolver := &UserResolver{} + user_resolver.Field.SetParent(query_resolver) + + ctx := context.WithValue(context.Background(), + ContextKey, BuildTree(graphql_query_user_nested, Variables{})) + + expected_fields := []string{"id", "name", "age", "height", "user"} + requested_fields := RequestedForContaining(ctx, "user") + + sort.Slice(expected_fields, func(i, j int) bool { return expected_fields[i] < expected_fields[j] }) + sort.Slice(requested_fields, func(i, j int) bool { return requested_fields[i] < requested_fields[j] }) + + assert.Equal(t, expected_fields, requested_fields) +}