Skip to content

feat(batching): Exclude operations that failed pre-executions (#1942) #1946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,9 +37,8 @@ import org.dataloader.DataLoader
class DataLoaderLevelDispatchedInstrumentation : AbstractExecutionLevelDispatchedInstrumentation() {
override fun getOnLevelDispatchedCallback(
parameters: ExecutionLevelDispatchedInstrumentationParameters
): OnLevelDispatchedCallback = { _, executions: List<ExecutionInput> ->
executions
.getOrNull(0)
): OnLevelDispatchedCallback = { _, _ ->
parameters.executionContext.executionInput
?.dataLoaderRegistry
?.dispatchAll()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,24 +22,26 @@ import com.expediagroup.graphql.dataloader.instrumentation.level.state.Level
import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.execution.ExecutionContext
import graphql.execution.ExecutionId
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext
import graphql.execution.instrumentation.Instrumentation
import graphql.execution.instrumentation.InstrumentationContext
import graphql.execution.instrumentation.InstrumentationState
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters
import graphql.execution.instrumentation.SimplePerformantInstrumentation
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters
import graphql.schema.DataFetcher

/**
* Represents the signature of a callback that will be executed when a [Level] is dispatched
*/
internal typealias OnLevelDispatchedCallback = (Level, List<ExecutionInput>) -> Unit
internal typealias OnLevelDispatchedCallback = (Level, List<ExecutionId>) -> Unit
/**
* Custom GraphQL [graphql.execution.instrumentation.Instrumentation] that calculate the state of executions
* of all queries sharing the same GraphQLContext map
*/
abstract class AbstractExecutionLevelDispatchedInstrumentation : Instrumentation {
abstract class AbstractExecutionLevelDispatchedInstrumentation : SimplePerformantInstrumentation() {
/**
* This is invoked each time instrumentation attempts to calculate a level dispatched state, this can be called from either
* `beginFieldField` or `beginExecutionStrategy`.
Expand All @@ -52,13 +54,13 @@ abstract class AbstractExecutionLevelDispatchedInstrumentation : Instrumentation
parameters: ExecutionLevelDispatchedInstrumentationParameters
): OnLevelDispatchedCallback

override fun beginExecuteOperation(
parameters: InstrumentationExecuteOperationParameters,
override fun beginExecution(
parameters: InstrumentationExecutionParameters,
state: InstrumentationState?
): InstrumentationContext<ExecutionResult>? =
parameters.executionContext.takeUnless(ExecutionContext::isMutation)
parameters.executionInput
?.graphQLContext?.get<ExecutionLevelDispatchedState>(ExecutionLevelDispatchedState::class)
?.beginExecuteOperation(parameters)
?.beginExecution(parameters)

override fun beginExecutionStrategy(
parameters: InstrumentationExecutionStrategyParameters,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,38 +20,63 @@ import com.expediagroup.graphql.dataloader.instrumentation.extensions.getExpecte
import com.expediagroup.graphql.dataloader.instrumentation.level.execution.OnLevelDispatchedCallback
import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.execution.ExecutionId
import graphql.execution.FieldValueInfo
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext
import graphql.execution.instrumentation.InstrumentationContext
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters
import graphql.execution.instrumentation.SimpleInstrumentationContext
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters
import graphql.schema.DataFetcher
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicReference

/**
* Orchestrate the [ExecutionBatchState] of all [ExecutionInput] sharing the same graphQLContext map,
* when a certain state is reached will invoke [OnLevelDispatchedCallback]
*/
class ExecutionLevelDispatchedState(
private val totalExecutions: Int
totalOperations: Int
) {
val executions = ConcurrentHashMap<ExecutionInput, ExecutionBatchState>()
private val totalExecutions: AtomicReference<Int> = AtomicReference(totalOperations)
val executions = ConcurrentHashMap<ExecutionId, ExecutionBatchState>()

/**
* Remove an [ExecutionBatchState] from the state in case operation does not qualify for execution,
* for example:
* parsing, validation, execution errors
* persisted query errors
*/
private fun removeExecution(executionId: ExecutionId) {
if (executions.containsKey(executionId)) {
executions.remove(executionId)
totalExecutions.set(totalExecutions.get() - 1)
}
}

/**
* Initialize the [ExecutionBatchState] of this [ExecutionInput]
*
* @param parameters contains information of which [ExecutionInput] will start his execution
* @return a nullable [InstrumentationContext]
*/
fun beginExecuteOperation(
parameters: InstrumentationExecuteOperationParameters
): InstrumentationContext<ExecutionResult>? {
executions.computeIfAbsent(parameters.executionContext.executionInput) {
fun beginExecution(
parameters: InstrumentationExecutionParameters
): InstrumentationContext<ExecutionResult> {
executions.computeIfAbsent(parameters.executionInput.executionId) {
ExecutionBatchState()
}
return null
return object : SimpleInstrumentationContext<ExecutionResult>() {
override fun onCompleted(result: ExecutionResult?, t: Throwable?) {
result?.let {
if (result.errors.size > 0) {
removeExecution(parameters.executionInput.executionId)
}
}
}
}
}

/**
Expand All @@ -64,11 +89,11 @@ class ExecutionLevelDispatchedState(
parameters: InstrumentationExecutionStrategyParameters,
onLevelDispatched: OnLevelDispatchedCallback
): ExecutionStrategyInstrumentationContext {
val executionInput = parameters.executionContext.executionInput
val executionId = parameters.executionContext.executionInput.executionId
val level = Level(parameters.executionStrategyParameters.path.level + 1)
val fieldCount = parameters.executionStrategyParameters.fields.size()

executions.computeIfPresent(executionInput) { _, executionState ->
executions.computeIfPresent(executionId) { _, executionState ->
executionState.also {
it.initializeLevelStateIfNeeded(level)
it.increaseExpectedFetches(level, fieldCount)
Expand All @@ -86,7 +111,7 @@ class ExecutionLevelDispatchedState(
override fun onFieldValuesInfo(fieldValueInfoList: List<FieldValueInfo>) {
val nextLevel = level.next()

executions.computeIfPresent(executionInput) { _, executionState ->
executions.computeIfPresent(executionId) { _, executionState ->
executionState.also {
it.increaseOnFieldValueInfos(level)
it.increaseExpectedExecutionStrategies(
Expand All @@ -104,7 +129,7 @@ class ExecutionLevelDispatchedState(
}

override fun onFieldValuesException() {
executions.computeIfPresent(executionInput) { _, executionState ->
executions.computeIfPresent(executionId) { _, executionState ->
executionState.also {
it.increaseOnFieldValueInfos(level)
}
Expand All @@ -123,14 +148,13 @@ class ExecutionLevelDispatchedState(
parameters: InstrumentationFieldFetchParameters,
onLevelDispatched: OnLevelDispatchedCallback
): InstrumentationContext<Any> {
val executionInput = parameters.executionContext.executionInput
val executionId = parameters.executionContext.executionInput.executionId
val path = parameters.executionStepInfo.path
val level = Level(path.level)

return object : InstrumentationContext<Any> {
return object : SimpleInstrumentationContext<Any>() {
override fun onDispatched(result: CompletableFuture<Any?>) {

executions.computeIfPresent(executionInput) { _, executionState ->
executions.computeIfPresent(executionId) { _, executionState ->
executionState.also { it.increaseDispatchedFetches(level) }
}

Expand All @@ -140,9 +164,6 @@ class ExecutionLevelDispatchedState(
executions.forEach { (_, executionState) -> executionState.completeDataFetchers(level) }
}
}

override fun onCompleted(result: Any?, t: Throwable?) {
}
}
}

Expand All @@ -161,7 +182,7 @@ class ExecutionLevelDispatchedState(
parameters: InstrumentationFieldFetchParameters
): DataFetcher<*> {
var manuallyCompletableDataFetcher: DataFetcher<*> = dataFetcher
executions.computeIfPresent(parameters.executionContext.executionInput) { _, executionState ->
executions.computeIfPresent(parameters.executionContext.executionInput.executionId) { _, executionState ->
executionState.also {
manuallyCompletableDataFetcher = it.toManuallyCompletableDataFetcher(
Level(parameters.executionStepInfo.path.level),
Expand All @@ -180,9 +201,11 @@ class ExecutionLevelDispatchedState(
* @param level that execution state will be calculated
* @return Boolean for allExecutionsDispatched statement
*/
fun allExecutionsDispatched(level: Level): Boolean =
executions
.takeIf { executions -> executions.size == totalExecutions }
?.all { (_, executionState) -> executionState.isLevelDispatched(level) }
?: false
fun allExecutionsDispatched(level: Level): Boolean = synchronized(executions) {
val operationsToExecute = totalExecutions.get()
when {
executions.size < operationsToExecute -> false
else -> executions.all { (_, executionState) -> executionState.isLevelDispatched(level) }
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@ import com.expediagroup.graphql.dataloader.instrumentation.syncexhaustion.execut
import com.expediagroup.graphql.dataloader.instrumentation.syncexhaustion.execution.SyncExecutionExhaustedInstrumentationParameters
import graphql.ExecutionInput
import graphql.GraphQLContext
import graphql.execution.ExecutionId
import graphql.execution.instrumentation.Instrumentation
import graphql.schema.DataFetcher
import org.dataloader.DataLoader
Expand All @@ -37,10 +38,10 @@ import java.util.concurrent.CompletableFuture
class DataLoaderSyncExecutionExhaustedInstrumentation : AbstractSyncExecutionExhaustedInstrumentation() {
override fun getOnSyncExecutionExhaustedCallback(
parameters: SyncExecutionExhaustedInstrumentationParameters
): OnSyncExecutionExhaustedCallback = { executions: List<ExecutionInput> ->
executions
.getOrNull(0)
?.dataLoaderRegistry
?.dispatchAll()
): OnSyncExecutionExhaustedCallback = { _: List<ExecutionId> ->
parameters
.executionContext.executionInput
.dataLoaderRegistry
.dispatchAll()
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Expedia, Inc
* Copyright 2024 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,24 +22,26 @@ import graphql.ExecutionInput
import graphql.ExecutionResult
import graphql.GraphQLContext
import graphql.execution.ExecutionContext
import graphql.execution.ExecutionId
import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext
import graphql.execution.instrumentation.Instrumentation
import graphql.execution.instrumentation.InstrumentationContext
import graphql.execution.instrumentation.InstrumentationState
import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters
import graphql.execution.instrumentation.SimplePerformantInstrumentation
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters
import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters
import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters

/**
* typealias that represents the signature of a callback that will be executed when sync execution is exhausted
*/
internal typealias OnSyncExecutionExhaustedCallback = (List<ExecutionInput>) -> Unit
internal typealias OnSyncExecutionExhaustedCallback = (List<ExecutionId>) -> Unit

/**
* Custom GraphQL [Instrumentation] that calculate the synchronous execution exhaustion
* of all GraphQL operations sharing the same [GraphQLContext]
*/
abstract class AbstractSyncExecutionExhaustedInstrumentation : Instrumentation {
abstract class AbstractSyncExecutionExhaustedInstrumentation : SimplePerformantInstrumentation() {
/**
* This is invoked each time instrumentation attempts to calculate exhaustion state, this can be called from either
* `beginFieldField.dispatch` or `beginFieldFetch.complete`.
Expand All @@ -51,13 +53,13 @@ abstract class AbstractSyncExecutionExhaustedInstrumentation : Instrumentation {
parameters: SyncExecutionExhaustedInstrumentationParameters
): OnSyncExecutionExhaustedCallback

override fun beginExecuteOperation(
parameters: InstrumentationExecuteOperationParameters,
override fun beginExecution(
parameters: InstrumentationExecutionParameters,
state: InstrumentationState?
): InstrumentationContext<ExecutionResult>? =
parameters.executionContext.takeUnless(ExecutionContext::isMutation)
?.graphQLContext?.get<SyncExecutionExhaustedState>(SyncExecutionExhaustedState::class)
?.beginExecuteOperation(parameters)
parameters.graphQLContext
?.get<SyncExecutionExhaustedState>(SyncExecutionExhaustedState::class)
?.beginExecution(parameters)

override fun beginExecutionStrategy(
parameters: InstrumentationExecutionStrategyParameters,
Expand Down
Loading