Skip to content

Commit 92fc2e5

Browse files
dariuszkucsmyrick
authored and
smyrick
committed
fix: directives with arguments should be created per declaration (ExpediaGroup#287)
Directive definitions should only be added to the schema once - while building the schema we are caching the directive definitions and add them after processing all the objects. Our current logic relies on this cache to not-rebuild the directives per each one of its declarations. Directives can have arguments which means that we have to rebuild the directive definition if they accept any arguments.
1 parent 66649df commit 92fc2e5

File tree

9 files changed

+111
-64
lines changed

9 files changed

+111
-64
lines changed

graphql-kotlin-schema-generator/src/main/kotlin/com/expedia/graphql/generator/SchemaGenerator.kt

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ internal class SchemaGenerator(internal val config: SchemaGeneratorConfig) {
6363
builder.additionalType(it)
6464
}
6565

66-
builder.additionalDirectives(state.directives)
66+
builder.additionalDirectives(state.directives.values.toSet())
6767
builder.codeRegistry(codeRegistry.build())
6868
return config.hooks.willBuildSchema(builder).build()
6969
}
@@ -95,15 +95,9 @@ internal class SchemaGenerator(internal val config: SchemaGeneratorConfig) {
9595
internal fun scalarType(type: KType, annotatedAsID: Boolean = false) =
9696
scalarTypeBuilder.scalarType(type, annotatedAsID)
9797

98-
internal fun directives(element: KAnnotatedElement): List<GraphQLDirective> {
99-
val directives = directiveTypeBuilder.directives(element)
100-
state.directives.addAll(directives)
101-
return directives
102-
}
98+
internal fun directives(element: KAnnotatedElement): List<GraphQLDirective> =
99+
directiveTypeBuilder.directives(element)
103100

104-
internal fun fieldDirectives(field: Field): List<GraphQLDirective> {
105-
val directives = directiveTypeBuilder.fieldDirectives(field)
106-
state.directives.addAll(directives)
107-
return directives
108-
}
101+
internal fun fieldDirectives(field: Field): List<GraphQLDirective> =
102+
directiveTypeBuilder.fieldDirectives(field)
109103
}

graphql-kotlin-schema-generator/src/main/kotlin/com/expedia/graphql/generator/state/SchemaGeneratorState.kt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,24 @@ import com.expedia.graphql.directives.DeprecatedDirective
44
import graphql.Directives
55
import graphql.schema.GraphQLDirective
66
import graphql.schema.GraphQLType
7+
import java.util.concurrent.ConcurrentHashMap
78

89
internal class SchemaGeneratorState(supportedPackages: List<String>) {
910
val cache = TypesCache(supportedPackages)
1011
val additionalTypes = mutableSetOf<GraphQLType>()
11-
val directives = mutableSetOf<GraphQLDirective>()
12+
val directives = ConcurrentHashMap<String, GraphQLDirective>()
1213

1314
fun getValidAdditionalTypes(): List<GraphQLType> = additionalTypes.filter { cache.doesNotContainGraphQLType(it) }
1415

1516
init {
1617
// NOTE: @include and @defer query directives are added by graphql-java by default
1718
// adding them explicitly here to keep it consistent with missing deprecated directive
18-
directives.add(Directives.IncludeDirective)
19-
directives.add(Directives.SkipDirective)
19+
directives[Directives.IncludeDirective.name] = Directives.IncludeDirective
20+
directives[Directives.SkipDirective.name] = Directives.SkipDirective
2021

2122
// graphql-kotlin default directives
22-
directives.add(DeprecatedDirective)
23+
// @deprecated directive is a built-in directive that each GraphQL server should provide bu currently it is not added by graphql-java
24+
// see https://github.com/graphql-java/graphql-java/issues/1598
25+
directives[DeprecatedDirective.name] = DeprecatedDirective
2326
}
2427
}

graphql-kotlin-schema-generator/src/main/kotlin/com/expedia/graphql/generator/types/DirectiveBuilder.kt

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,38 +25,48 @@ internal class DirectiveBuilder(generator: SchemaGenerator) : TypeBuilder(genera
2525
.map(this::getDirective)
2626

2727
private fun getDirective(directiveInfo: DirectiveInfo): GraphQLDirective {
28-
29-
val existingDirective = state.directives.find { it.name == directiveInfo.effectiveName }
30-
31-
if (existingDirective != null) {
32-
return existingDirective
28+
val directiveName = directiveInfo.effectiveName
29+
val directive = state.directives.computeIfAbsent(directiveName) {
30+
val builder = GraphQLDirective.newDirective()
31+
.name(directiveInfo.effectiveName)
32+
.description(directiveInfo.directiveAnnotation.description)
33+
34+
directiveInfo.directiveAnnotation.locations.forEach {
35+
builder.validLocation(it)
36+
}
37+
38+
val directiveClass = directiveInfo.directive.annotationClass
39+
directiveClass.getValidProperties(config.hooks).forEach { prop ->
40+
val propertyName = prop.name
41+
val value = prop.call(directiveInfo.directive)
42+
val type = graphQLTypeOf(prop.returnType)
43+
44+
val argument = GraphQLArgument.newArgument()
45+
.name(propertyName)
46+
.value(value)
47+
.type(type.safeCast())
48+
.build()
49+
50+
builder.argument(argument)
51+
}
52+
builder.build()
3353
}
3454

35-
val directiveClass = directiveInfo.directive.annotationClass
36-
37-
val builder = GraphQLDirective.newDirective()
38-
.name(directiveInfo.effectiveName)
39-
.description(directiveInfo.directiveAnnotation.description)
40-
41-
directiveInfo.directiveAnnotation.locations.forEach {
42-
builder.validLocation(it)
55+
return if (directive.arguments.isNotEmpty()) {
56+
// update args for this instance
57+
val builder = GraphQLDirective.newDirective(directive)
58+
directiveInfo.directive.annotationClass.getValidProperties(config.hooks).forEach { prop ->
59+
val defaultArgument = directive.getArgument(prop.name)
60+
val value = prop.call(directiveInfo.directive)
61+
val argument = GraphQLArgument.newArgument(defaultArgument)
62+
.value(value)
63+
.build()
64+
builder.argument(argument)
65+
}
66+
builder.build()
67+
} else {
68+
directive
4369
}
44-
45-
directiveClass.getValidProperties(config.hooks).forEach { prop ->
46-
val propertyName = prop.name
47-
val value = prop.call(directiveInfo.directive)
48-
val type = graphQLTypeOf(prop.returnType)
49-
50-
val argument = GraphQLArgument.newArgument()
51-
.name(propertyName)
52-
.value(value)
53-
.type(type.safeCast())
54-
.build()
55-
56-
builder.argument(argument)
57-
}
58-
59-
return builder.build()
6070
}
6171
}
6272

graphql-kotlin-schema-generator/src/test/kotlin/com/expedia/graphql/generator/types/DirectiveBuilderTest.kt

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import com.expedia.graphql.generator.SchemaGenerator
77
import com.expedia.graphql.generator.extensions.isTrue
88
import com.expedia.graphql.getTestSchemaConfigWithMockedDirectives
99
import graphql.Directives
10+
import org.junit.jupiter.api.BeforeEach
1011
import org.junit.jupiter.api.Test
1112
import kotlin.reflect.KClass
1213
import kotlin.test.assertEquals
@@ -48,14 +49,22 @@ internal class DirectiveBuilderTest {
4849
@DirectiveWithString(string = "foo")
4950
fun directiveWithString(string: String) = string
5051

52+
@DirectiveWithString(string = "bar")
53+
fun directiveWithAnotherString(string: String) = string
54+
5155
@DirectiveWithEnum(type = Type.TWO)
5256
fun directiveWithEnum(string: String) = string
5357

5458
@DirectiveWithClass(kclass = Type::class)
5559
fun directiveWithClass(string: String) = string
5660
}
5761

58-
private val basicGenerator = SchemaGenerator(getTestSchemaConfigWithMockedDirectives())
62+
private lateinit var basicGenerator: SchemaGenerator
63+
64+
@BeforeEach
65+
fun setUp() {
66+
basicGenerator = SchemaGenerator(getTestSchemaConfigWithMockedDirectives())
67+
}
5968

6069
@Test
6170
fun `no annotation`() {
@@ -88,14 +97,17 @@ internal class DirectiveBuilderTest {
8897
}
8998

9099
@Test
91-
fun `directives are not duplicated in the schema`() {
100+
fun `directives are only added to the schema once`() {
92101
val initialCount = basicGenerator.state.directives.size
93-
assertTrue(basicGenerator.state.directives.contains(Directives.IncludeDirective))
94-
assertTrue(basicGenerator.state.directives.contains(Directives.SkipDirective))
95-
assertTrue(basicGenerator.state.directives.contains(DeprecatedDirective))
96-
97-
basicGenerator.directives(MyClass::simpleDirective)
98-
basicGenerator.directives(MyClass::simpleDirective)
102+
assertTrue(basicGenerator.state.directives.containsKey(Directives.IncludeDirective.name))
103+
assertTrue(basicGenerator.state.directives.containsKey(Directives.SkipDirective.name))
104+
assertTrue(basicGenerator.state.directives.containsKey(DeprecatedDirective.name))
105+
106+
val firstInvocation = basicGenerator.directives(MyClass::simpleDirective)
107+
assertEquals(1, firstInvocation.size)
108+
val secondInvocation = basicGenerator.directives(MyClass::simpleDirective)
109+
assertEquals(1, secondInvocation.size)
110+
assertEquals(firstInvocation.first(), secondInvocation.first())
99111
assertEquals(initialCount + 1, basicGenerator.state.directives.size)
100112
}
101113

@@ -118,4 +130,22 @@ internal class DirectiveBuilderTest {
118130

119131
assertEquals(0, directives.size)
120132
}
133+
134+
@Test
135+
fun `directives are created per each declaration`() {
136+
val initialCount = basicGenerator.state.directives.size
137+
val directivesOnFirstField = basicGenerator.directives(MyClass::directiveWithString)
138+
val directivesOnSecondField = basicGenerator.directives(MyClass::directiveWithAnotherString)
139+
assertEquals(expected = 1, actual = directivesOnFirstField.size)
140+
assertEquals(expected = 1, actual = directivesOnSecondField.size)
141+
142+
val firstDirective = directivesOnFirstField.first()
143+
val seconDirective = directivesOnSecondField.first()
144+
assertEquals("directiveWithString", firstDirective.name)
145+
assertEquals("directiveWithString", seconDirective.name)
146+
assertEquals("foo", firstDirective.getArgument("string")?.value)
147+
assertEquals("bar", seconDirective.getArgument("string")?.value)
148+
149+
assertEquals(initialCount + 1, basicGenerator.state.directives.size)
150+
}
121151
}

graphql-kotlin-schema-generator/src/test/kotlin/com/expedia/graphql/generator/types/TypeTestHelper.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,16 @@ internal open class TypeTestHelper {
101101
directiveBuilder = spyk(DirectiveBuilder(generator))
102102
every { generator.directives(any()) } answers {
103103
val directives = directiveBuilder!!.directives(it.invocation.args[0] as KAnnotatedElement)
104-
state.directives.addAll(directives)
104+
for (directive in directives) {
105+
state.directives[directive.name] = directive
106+
}
105107
directives
106108
}
107109
every { generator.fieldDirectives(any()) } answers {
108110
val directives = directiveBuilder!!.fieldDirectives(it.invocation.args[0] as Field)
109-
state.directives.addAll(directives)
111+
for (directive in directives) {
112+
state.directives[directive.name] = directive
113+
}
110114
directives
111115
}
112116

graphql-kotlin-spring-example/src/main/kotlin/com/expedia/graphql/sample/directives/CustomDirectiveWiringFactory.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ import kotlin.reflect.KClass
1010
class CustomDirectiveWiringFactory : KotlinDirectiveWiringFactory(manualWiring = mapOf<String, KotlinSchemaDirectiveWiring>("lowercase" to LowercaseSchemaDirectiveWiring())) {
1111

1212
private val stringEvalDirectiveWiring = StringEvalSchemaDirectiveWiring()
13-
private val caleOnlyDirectiveWiring = CakeOnlySchemaDirectiveWiring()
13+
private val caleOnlyDirectiveWiring = SpecificValueOnlySchemaDirectiveWiring()
1414

1515
override fun getSchemaDirectiveWiring(environment: KotlinSchemaDirectiveEnvironment<GraphQLDirectiveContainer>): KotlinSchemaDirectiveWiring? = when {
1616
environment.directive.name == getDirectiveName(StringEval::class) -> stringEvalDirectiveWiring
17-
environment.directive.name == getDirectiveName(CakeOnly::class) -> caleOnlyDirectiveWiring
17+
environment.directive.name == getDirectiveName(SpecificValueOnly::class) -> caleOnlyDirectiveWiring
1818
else -> null
1919
}
2020
}

graphql-kotlin-spring-example/src/main/kotlin/com/expedia/graphql/sample/directives/CakeOnly.kt renamed to graphql-kotlin-spring-example/src/main/kotlin/com/expedia/graphql/sample/directives/SpecificValueOnly.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ package com.expedia.graphql.sample.directives
22

33
import com.expedia.graphql.annotations.GraphQLDirective
44

5-
@GraphQLDirective(description = "This validates inputted string is equal to cake")
6-
annotation class CakeOnly
5+
@GraphQLDirective(description = "This validates inputted string is equal to specified argument")
6+
annotation class SpecificValueOnly(val value: String)
Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,19 @@ import com.expedia.graphql.directives.KotlinSchemaDirectiveWiring
55
import graphql.schema.DataFetcher
66
import graphql.schema.GraphQLFieldDefinition
77

8-
class CakeOnlySchemaDirectiveWiring : KotlinSchemaDirectiveWiring {
8+
class SpecificValueOnlySchemaDirectiveWiring : KotlinSchemaDirectiveWiring {
99

1010
@Throws(RuntimeException::class)
1111
override fun onField(environment: KotlinFieldDirectiveEnvironment): GraphQLFieldDefinition {
1212
val field = environment.element
1313
val originalDataFetcher: DataFetcher<Any> = environment.getDataFetcher()
1414

15+
val supportedValue = environment.directive.getArgument("value")?.value?.toString() ?: ""
16+
1517
val cakeOnlyFetcher = DataFetcher<Any> { dataEnv ->
1618
val strArg: String? = dataEnv.getArgument(environment.element.arguments[0].name) as String?
17-
if (!"cake".equals(other = strArg, ignoreCase = true)) {
18-
throw RuntimeException("The cake is a lie!")
19+
if (!supportedValue.equals(other = strArg, ignoreCase = true)) {
20+
throw RuntimeException("Unsupported value, expected=$supportedValue actual=$strArg")
1921
}
2022
originalDataFetcher.get(dataEnv)
2123
}
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package com.expedia.graphql.sample.query
22

33
import com.expedia.graphql.annotations.GraphQLDescription
4-
import com.expedia.graphql.sample.directives.CakeOnly
54
import com.expedia.graphql.sample.directives.LowercaseDirective
5+
import com.expedia.graphql.sample.directives.SpecificValueOnly
66
import com.expedia.graphql.sample.directives.StringEval
77
import org.springframework.stereotype.Component
88

@@ -13,10 +13,14 @@ class CustomDirectiveQuery : Query {
1313
fun justWhisper(@StringEval(default = "default string", lowerCase = true) msg: String?): String? = msg
1414

1515
@GraphQLDescription("This will only accept 'Cake' as input")
16-
@CakeOnly
16+
@SpecificValueOnly("cake")
1717
fun onlyCake(msg: String): String = "<3"
1818

19+
@GraphQLDescription("This will only accept 'IceCream' as input")
20+
@SpecificValueOnly("icecream")
21+
fun onlyIceCream(msg: String): String = "<3"
22+
1923
@GraphQLDescription("Returns message modified by the manually wired directive to force lowercase")
2024
@LowercaseDirective
2125
fun forceLowercaseEcho(msg: String) = msg
22-
}
26+
}

0 commit comments

Comments
 (0)