Skip to content

Allow additional types to be added with custom SchemaGenerator #587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package com.expediagroup.graphql.federation.validation

import com.expediagroup.graphql.extensions.unwrapType
import com.expediagroup.graphql.federation.directives.PROVIDES_DIRECTIVE_NAME
import com.expediagroup.graphql.federation.extensions.isExtendedType
import graphql.schema.GraphQLFieldDefinition
import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLTypeReference
import graphql.schema.GraphQLTypeUtil

// [OK] @provides on base type references valid @external fields on @extend object
// [ERROR] @provides on base type references local object fields
Expand All @@ -30,7 +30,7 @@ import graphql.schema.GraphQLTypeUtil
// [OK] @provides references list of valid @extend objects
// [ERROR] @provides references @external list field
// [ERROR] @provides references @external interface field
internal fun validateProvidesDirective(federatedType: String, field: GraphQLFieldDefinition): List<String> = when (val returnType = GraphQLTypeUtil.unwrapType(field.type).last()) {
internal fun validateProvidesDirective(federatedType: String, field: GraphQLFieldDefinition): List<String> = when (val returnType = field.type.unwrapType()) {
is GraphQLObjectType -> {
if (!returnType.isExtendedType()) {
listOf("@provides directive is specified on a $federatedType.${field.name} field references local object")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.expediagroup.graphql.federation.types

import graphql.schema.GraphQLTypeUtil
import com.expediagroup.graphql.extensions.unwrapType
import graphql.schema.GraphQLUnionType
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals
Expand All @@ -41,7 +41,7 @@ internal class EntityTest {
assertFalse(result.description.isNullOrEmpty())
assertEquals(expected = 1, actual = result.arguments.size)

val graphQLUnionType = GraphQLTypeUtil.unwrapType(result.type).last() as? GraphQLUnionType
val graphQLUnionType = result.type.unwrapType() as? GraphQLUnionType

assertNotNull(graphQLUnionType)
assertEquals(expected = "_Entity", actual = graphQLUnionType.name)
Expand All @@ -52,7 +52,7 @@ internal class EntityTest {
@Test
fun `generateEntityFieldDefinition should return a valid type on a multiple values`() {
val result = generateEntityFieldDefinition(setOf("MyType", "MySecondType"))
val graphQLUnionType = GraphQLTypeUtil.unwrapType(result.type).last() as? GraphQLUnionType
val graphQLUnionType = result.type.unwrapType() as? GraphQLUnionType

assertNotNull(graphQLUnionType)
assertEquals(expected = 2, actual = graphQLUnionType.types.size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,10 @@ import graphql.schema.GraphQLTypeUtil
*/
val GraphQLType.deepName: String
get() = GraphQLTypeUtil.simplePrint(this)

/**
* Unwrap the type of all layers and return the last element.
* This includes GraphQLNonNull and GraphQLList.
* If the type is not wrapped, it will just be returned.
*/
fun GraphQLType.unwrapType(): GraphQLType = GraphQLTypeUtil.unwrapType(this).last()
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import graphql.schema.GraphQLSchema
import graphql.schema.GraphQLType
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.full.createType

/**
Expand All @@ -45,7 +46,7 @@ open class SchemaGenerator(internal val config: SchemaGeneratorConfig) {
internal val classScanner = ClassScanner(config.supportedPackages)
internal val cache = TypesCache(config.supportedPackages)
internal val codeRegistry = GraphQLCodeRegistry.newCodeRegistry()
internal val additionalTypes = mutableSetOf<GraphQLType>()
internal val additionalTypes = mutableSetOf<KType>()
internal val directives = ConcurrentHashMap<String, GraphQLDirective>()

init {
Expand All @@ -68,16 +69,12 @@ open class SchemaGenerator(internal val config: SchemaGeneratorConfig) {
mutations: List<TopLevelObject> = emptyList(),
subscriptions: List<TopLevelObject> = emptyList()
): GraphQLSchema {

val builder = GraphQLSchema.newSchema()
builder.query(generateQueries(this, queries))
builder.mutation(generateMutations(this, mutations))
builder.subscription(generateSubscriptions(this, subscriptions))

// add unreferenced interface implementations
additionalTypes.forEach {
builder.additionalType(it)
}

builder.additionalTypes(generateAdditionalTypes(additionalTypes))
builder.additionalDirectives(directives.values.toSet())
builder.codeRegistry(codeRegistry.build())

Expand All @@ -94,10 +91,10 @@ open class SchemaGenerator(internal val config: SchemaGeneratorConfig) {
* This is helpful for things like federation or combining external schemas
*/
protected fun addAdditionalTypesWithAnnotation(annotation: KClass<*>) {
classScanner.getClassesWithAnnotation(annotation)
.map { generateGraphQLType(this, it.createType()) }
.forEach {
additionalTypes.add(it)
}
classScanner.getClassesWithAnnotation(annotation).forEach {
additionalTypes.add(it.createType())
}
}

private fun generateAdditionalTypes(additionalTypes: Set<KType>): Set<GraphQLType> = additionalTypes.map { generateGraphQLType(this, it) }.toSet()
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.expediagroup.graphql.generator.types

import com.expediagroup.graphql.extensions.unwrapType
import com.expediagroup.graphql.generator.SchemaGenerator
import com.expediagroup.graphql.generator.extensions.getKClass
import com.expediagroup.graphql.generator.extensions.isEnum
Expand All @@ -26,7 +27,6 @@ import com.expediagroup.graphql.generator.extensions.wrapInNonNull
import com.expediagroup.graphql.generator.state.TypesCacheKey
import graphql.schema.GraphQLType
import graphql.schema.GraphQLTypeReference
import graphql.schema.GraphQLTypeUtil
import kotlin.reflect.KClass
import kotlin.reflect.KType

Expand All @@ -40,7 +40,7 @@ internal fun generateGraphQLType(generator: SchemaGenerator, type: KType, inputT
?: objectFromReflection(generator, type, inputType)

// Do not call the hook on GraphQLTypeReference as we have not generated the type yet
val unwrappedType = GraphQLTypeUtil.unwrapType(graphQLType).lastElement()
val unwrappedType = graphQLType.unwrapType()
val typeWithNullability = graphQLType.wrapInNonNull(type)
if (unwrappedType !is GraphQLTypeReference) {
return generator.config.hooks.didGenerateGraphQLType(type, typeWithNullability)
Expand All @@ -58,9 +58,11 @@ private fun objectFromReflection(generator: SchemaGenerator, type: KType, inputT
}

val kClass = type.getKClass()
val graphQLType = generator.cache.buildIfNotUnderConstruction(kClass, inputType) { getGraphQLType(generator, kClass, inputType, type) }

return generator.config.hooks.willAddGraphQLTypeToSchema(type, graphQLType)
Copy link
Contributor Author

@smyrick smyrick Feb 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not want to call the hook with type references. This is now required for the new format

return generator.cache.buildIfNotUnderConstruction(kClass, inputType) {
val graphQLType = getGraphQLType(generator, kClass, inputType, type)
generator.config.hooks.willAddGraphQLTypeToSchema(type, graphQLType)
}
}

private fun getGraphQLType(generator: SchemaGenerator, kClass: KClass<*>, inputType: Boolean, type: KType): GraphQLType = when {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ import com.expediagroup.graphql.generator.extensions.getValidSuperclasses
import com.expediagroup.graphql.generator.extensions.safeCast
import graphql.TypeResolutionEnvironment
import graphql.schema.GraphQLInterfaceType
import graphql.schema.GraphQLTypeReference
import graphql.schema.GraphQLTypeUtil
import kotlin.reflect.KClass
import kotlin.reflect.full.createType

Expand All @@ -53,14 +51,7 @@ internal fun generateInterface(generator: SchemaGenerator, kClass: KClass<*>): G
.forEach { builder.field(generateFunction(generator, it, kClass.getSimpleName(), null, abstract = true)) }

generator.classScanner.getSubTypesOf(kClass)
.map { generateGraphQLType(generator, it.createType()) }
.forEach {
// Do not add objects currently under construction to the additional types
val unwrappedType = GraphQLTypeUtil.unwrapType(it).last()
if (unwrappedType !is GraphQLTypeReference) {
generator.additionalTypes.add(it)
}
}
.forEach { generator.additionalTypes.add(it.createType()) }

val interfaceType = builder.build()
generator.codeRegistry.typeResolver(interfaceType) { env: TypeResolutionEnvironment -> env.schema.getObjectType(env.getObject<Any>().javaClass.kotlin.getSimpleName()) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.expediagroup.graphql.generator.types

import com.expediagroup.graphql.extensions.unwrapType
import com.expediagroup.graphql.generator.SchemaGenerator
import com.expediagroup.graphql.generator.extensions.getGraphQLDescription
import com.expediagroup.graphql.generator.extensions.getSimpleName
Expand All @@ -26,7 +27,6 @@ import com.expediagroup.graphql.generator.extensions.safeCast
import graphql.schema.GraphQLInterfaceType
import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLTypeReference
import graphql.schema.GraphQLTypeUtil
import kotlin.reflect.KClass
import kotlin.reflect.full.createType

Expand All @@ -44,7 +44,7 @@ internal fun generateObject(generator: SchemaGenerator, kClass: KClass<*>): Grap
kClass.getValidSuperclasses(generator.config.hooks)
.map { generateGraphQLType(generator, it.createType()) }
.forEach {
when (val unwrappedType = GraphQLTypeUtil.unwrapType(it).last()) {
when (val unwrappedType = it.unwrapType()) {
is GraphQLTypeReference -> builder.withInterface(unwrappedType)
is GraphQLInterfaceType -> builder.withInterface(unwrappedType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

package com.expediagroup.graphql.generator.types

import com.expediagroup.graphql.extensions.unwrapType
import com.expediagroup.graphql.generator.SchemaGenerator
import com.expediagroup.graphql.generator.extensions.getGraphQLDescription
import com.expediagroup.graphql.generator.extensions.getSimpleName
import com.expediagroup.graphql.generator.extensions.safeCast
import graphql.TypeResolutionEnvironment
import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLTypeReference
import graphql.schema.GraphQLTypeUtil
import graphql.schema.GraphQLUnionType
import kotlin.reflect.KClass
import kotlin.reflect.full.createType
Expand All @@ -40,7 +40,7 @@ internal fun generateUnion(generator: SchemaGenerator, kClass: KClass<*>): Graph
generator.classScanner.getSubTypesOf(kClass)
.map { generateGraphQLType(generator, it.createType()) }
.forEach {
when (val unwrappedType = GraphQLTypeUtil.unwrapType(it).last()) {
when (val unwrappedType = it.unwrapType()) {
is GraphQLTypeReference -> builder.possibleType(unwrappedType)
is GraphQLObjectType -> builder.possibleType(unwrappedType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ internal class FunctionDataFetcherTest {

fun throwException() { throw GraphQLException("Test Exception") }

suspend fun suspendThrow(value: String?): String = coroutineScope {
value ?: throw GraphQLException("Suspended Exception")
suspend fun suspendThrow(): String = coroutineScope<String> {
throw GraphQLException("Suspended Exception")
}

@GraphQLName("myCustomField")
Expand Down Expand Up @@ -155,7 +155,6 @@ internal class FunctionDataFetcherTest {
fun `suspendThrow throws exception when resolved`() {
val dataFetcher = FunctionDataFetcher(target = MyClass(), fn = MyClass::suspendThrow)
val mockEnvironmet: DataFetchingEnvironment = mockk()
every { mockEnvironmet.arguments } returns mapOf("value" to null)

try {
val result = dataFetcher.get(mockEnvironmet)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import io.mockk.mockk
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals

internal class DeepNameKtTest {
internal class GraphqlTypeExtensionsKtTest {

private val basicType = mockk<GraphQLNamedType> {
every { name } returns "BasicType"
Expand Down Expand Up @@ -52,4 +52,27 @@ internal class DeepNameKtTest {
val complicated = GraphQLNonNull(GraphQLList(GraphQLNonNull(basicType)))
assertEquals(expected = "[BasicType!]!", actual = complicated.deepName)
}

@Test
fun `unwrapType works on basic types that are not wrapped`() {
assertEquals("BasicType", basicType.unwrapType().deepName)
}

@Test
fun `unwrapType works on non null`() {
val nonNull = GraphQLNonNull.nonNull(basicType)
assertEquals("BasicType", nonNull.unwrapType().deepName)
}

@Test
fun `unwrapType works on lists`() {
val graphQLList = GraphQLList.list(basicType)
assertEquals("BasicType", graphQLList.unwrapType().deepName)
}

@Test
fun `unwrapType works on multiple layers`() {
val graphQLList = GraphQLNonNull.nonNull(GraphQLList.list(GraphQLNonNull.nonNull(basicType)))
assertEquals("BasicType", graphQLList.unwrapType().deepName)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.expediagroup.graphql.generator.filters

import com.expediagroup.graphql.annotations.GraphQLIgnore
import org.junit.jupiter.api.Test
import kotlin.reflect.KClass
import kotlin.test.assertFalse
Expand All @@ -35,12 +36,18 @@ class SuperclassFiltersKtTest {
fun internal(): String
}

@GraphQLIgnore
interface IgnoredInterface {
fun public(): String
}

@Test
fun superclassFilters() {
assertTrue(isValidSuperclass(Interface::class))
assertFalse(isValidSuperclass(Union::class))
assertFalse(isValidSuperclass(NonPublic::class))
assertFalse(isValidSuperclass(Class::class))
assertFalse(isValidSuperclass(IgnoredInterface::class))
}

private fun isValidSuperclass(kClass: KClass<*>): Boolean = superclassFilters.all { it(kClass) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ import com.expediagroup.graphql.generator.extensions.getSimpleName
import com.expediagroup.graphql.getTestSchemaConfigWithHooks
import com.expediagroup.graphql.testSchemaConfig
import com.expediagroup.graphql.toSchema
import graphql.language.StringValue
import graphql.schema.Coercing
import graphql.schema.GraphQLFieldDefinition
import graphql.schema.GraphQLInterfaceType
import graphql.schema.GraphQLObjectType
import graphql.schema.GraphQLScalarType
import graphql.schema.GraphQLSchema
import graphql.schema.GraphQLType
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.util.UUID
import kotlin.random.Random
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
Expand Down Expand Up @@ -251,10 +255,40 @@ class SchemaGeneratorHooksTest {
assertEquals(expected = "SomeData", actual = hooks.willResolveMonad(type).getSimpleName())
}

@Test
fun `willGenerateGraphQLType can override to provide a custom type`() {
class MockSchemaGeneratorHooks : SchemaGeneratorHooks {
var hookCalled = false

override fun willGenerateGraphQLType(type: KType): GraphQLType? {
hookCalled = true

return when (type.classifier as? KClass<*>) {
UUID::class -> graphqlUUIDType
else -> null
}
}
}

val hooks = MockSchemaGeneratorHooks()
val schema = toSchema(
queries = listOf(TopLevelObject(CustomTypesQuery())),
config = getTestSchemaConfigWithHooks(hooks)
)

assertTrue(hooks.hookCalled)
val graphQLType = assertNotNull(schema.getType("UUID"))
assertTrue(graphQLType is GraphQLScalarType)
}

class TestQuery {
fun query(): SomeData = SomeData("someData", 0)
}

class CustomTypesQuery {
fun uuid(): UUID = UUID.randomUUID()
}

class TestInterfaceQuery {
fun randomQuery(): RandomData = if (Random.nextBoolean()) {
SomeData("random", 1)
Expand Down Expand Up @@ -294,4 +328,21 @@ class SchemaGeneratorHooksTest {
}

class EmptyImplementation(override val id: String) : EmptyInterface

private val graphqlUUIDType = GraphQLScalarType.newScalar()
.name("UUID")
.description("A type representing a formatted java.util.UUID")
.coercing(UUIDCoercing)
.build()

private object UUIDCoercing : Coercing<UUID, String> {
override fun parseValue(input: Any?): UUID = UUID.fromString(serialize(input))

override fun parseLiteral(input: Any?): UUID? {
val uuidString = (input as? StringValue)?.value
return UUID.fromString(uuidString)
}

override fun serialize(dataFetcherResult: Any?): String = dataFetcherResult.toString()
}
}