Skip to content

Commit ef1b973

Browse files
authored
Address #66 by providing an interceptor-based API to inject CoroutineContexts. (#143)
* Revert "Add a ServerInterceptor API for injecting custom CoroutineContext elements." This reverts commit 989f1f1. * Add CoroutineContextServerInterceptor, an API to add elements to the CoroutineContext of the server.
1 parent ef5c808 commit ef1b973

File tree

3 files changed

+160
-4
lines changed

3 files changed

+160
-4
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package io.grpc.kotlin
2+
3+
import io.grpc.Metadata
4+
import io.grpc.ServerCall
5+
import io.grpc.ServerCallHandler
6+
import io.grpc.ServerInterceptor
7+
import kotlin.coroutines.CoroutineContext
8+
import kotlin.coroutines.EmptyCoroutineContext
9+
import io.grpc.Context as GrpcContext
10+
11+
/**
12+
* A [ServerInterceptor] subtype that can install elements in the [CoroutineContext] where server
13+
* logic is executed. These elements are applied "after" the
14+
* [AbstractCoroutineServerImpl.context]; that is, the interceptor overrides the server's context.
15+
*/
16+
abstract class CoroutineContextServerInterceptor : ServerInterceptor {
17+
companion object {
18+
// This is deliberately kept visibility-restricted; it's intentional that the only way to affect
19+
// the CoroutineContext is to extend CoroutineContextServerInterceptor.
20+
internal val COROUTINE_CONTEXT_KEY : GrpcContext.Key<CoroutineContext> =
21+
GrpcContext.keyWithDefault("grpc-kotlin-coroutine-context", EmptyCoroutineContext)
22+
23+
private fun GrpcContext.extendCoroutineContext(coroutineContext: CoroutineContext): GrpcContext {
24+
val oldCoroutineContext: CoroutineContext = COROUTINE_CONTEXT_KEY[this]
25+
val newCoroutineContext = oldCoroutineContext + coroutineContext
26+
return withValue(COROUTINE_CONTEXT_KEY, newCoroutineContext)
27+
}
28+
}
29+
30+
/**
31+
* Override this function to return a [CoroutineContext] in which to execute [call] and [headers].
32+
* The returned [CoroutineContext] will override any corresponding context elements in the
33+
* server object.
34+
*
35+
* This function will be called each time a [call] is executed.
36+
*/
37+
abstract fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext
38+
39+
private inline fun <R> withGrpcContext(context: GrpcContext, action: () -> R): R {
40+
val oldContext: GrpcContext = context.attach()
41+
return try {
42+
action()
43+
} finally {
44+
context.detach(oldContext)
45+
}
46+
}
47+
48+
final override fun <ReqT, RespT> interceptCall(
49+
call: ServerCall<ReqT, RespT>,
50+
headers: Metadata,
51+
next: ServerCallHandler<ReqT, RespT>
52+
): ServerCall.Listener<ReqT> =
53+
withGrpcContext(GrpcContext.current().extendCoroutineContext(coroutineContext(call, headers))) {
54+
next.startCall(call, headers)
55+
}
56+
}

stub/src/main/java/io/grpc/kotlin/ServerCalls.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,9 @@ object ServerCalls {
196196
): ServerCallHandler<RequestT, ResponseT> =
197197
ServerCallHandler {
198198
call, _ -> serverCallListener(
199-
context + GrpcContextElement.current(),
199+
context
200+
+ CoroutineContextServerInterceptor.COROUTINE_CONTEXT_KEY.get()
201+
+ GrpcContextElement.current(),
200202
call,
201203
implementation
202204
)
@@ -235,9 +237,7 @@ object ServerCalls {
235237
}
236238

237239
val rpcScope = CoroutineScope(context)
238-
val rpcJob = rpcScope.async(
239-
CoroutineName("${call.methodDescriptor.fullMethodName} implementation")
240-
) {
240+
val rpcJob = rpcScope.async {
241241
runCatching {
242242
implementation(requests).collect {
243243
readiness.suspendUntilReady()
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package io.grpc.kotlin
2+
3+
import com.google.common.truth.Truth.assertThat
4+
import io.grpc.ServerCall
5+
import io.grpc.ServerInterceptors
6+
import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineImplBase
7+
import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub
8+
import io.grpc.examples.helloworld.HelloReply
9+
import io.grpc.examples.helloworld.HelloRequest
10+
import org.junit.Test
11+
import org.junit.runner.RunWith
12+
import org.junit.runners.JUnit4
13+
import kotlin.coroutines.CoroutineContext
14+
import kotlin.coroutines.EmptyCoroutineContext
15+
import kotlin.coroutines.coroutineContext
16+
import io.grpc.Metadata as GrpcMetadata
17+
18+
/** Tests for [CoroutineContextServerInterceptor]. */
19+
@RunWith(JUnit4::class) /* inserted by Copybara: */ @com.google.testing.testsize.MediumTest
20+
class CoroutineContextServerInterceptorTest : AbstractCallsTest() {
21+
class ArbitraryContextElement(val message: String = "") : CoroutineContext.Element {
22+
companion object Key : CoroutineContext.Key<ArbitraryContextElement>
23+
override val key: CoroutineContext.Key<*>
24+
get() = Key
25+
}
26+
27+
class HelloReplyWithContextMessage(
28+
message: String? = null
29+
) : GreeterCoroutineImplBase(
30+
message?.let { ArbitraryContextElement(it) } ?: EmptyCoroutineContext
31+
) {
32+
override suspend fun sayHello(request: HelloRequest): HelloReply =
33+
helloReply(coroutineContext[ArbitraryContextElement]!!.message)
34+
}
35+
36+
@Test
37+
fun injectContext() {
38+
val interceptor = object : CoroutineContextServerInterceptor() {
39+
override fun coroutineContext(
40+
call: ServerCall<*, *>,
41+
headers: GrpcMetadata
42+
): CoroutineContext = ArbitraryContextElement("success")
43+
}
44+
45+
val channel = makeChannel(HelloReplyWithContextMessage(), interceptor)
46+
val client = GreeterCoroutineStub(channel)
47+
48+
runBlocking {
49+
assertThat(client.sayHello(helloRequest("")).message).isEqualTo("success")
50+
}
51+
}
52+
53+
@Test
54+
fun conflictingInterceptorsInnermostWins() {
55+
val interceptor1 = object : CoroutineContextServerInterceptor() {
56+
override fun coroutineContext(
57+
call: ServerCall<*, *>,
58+
headers: GrpcMetadata
59+
): CoroutineContext = ArbitraryContextElement("first")
60+
}
61+
val interceptor2 = object : CoroutineContextServerInterceptor() {
62+
override fun coroutineContext(
63+
call: ServerCall<*, *>,
64+
headers: GrpcMetadata
65+
): CoroutineContext = ArbitraryContextElement("second")
66+
}
67+
68+
val channel = makeChannel(
69+
ServerInterceptors.intercept(
70+
ServerInterceptors.intercept(
71+
HelloReplyWithContextMessage(),
72+
interceptor2
73+
),
74+
interceptor1
75+
)
76+
)
77+
val client = GreeterCoroutineStub(channel)
78+
79+
runBlocking {
80+
assertThat(client.sayHello(helloRequest("")).message).isEqualTo("second")
81+
}
82+
}
83+
84+
@Test
85+
fun interceptorContextTakesPriority() {
86+
val interceptor = object : CoroutineContextServerInterceptor() {
87+
override fun coroutineContext(
88+
call: ServerCall<*, *>,
89+
headers: GrpcMetadata
90+
): CoroutineContext = ArbitraryContextElement("interceptor")
91+
}
92+
93+
val channel = makeChannel(HelloReplyWithContextMessage("server"), interceptor)
94+
val client = GreeterCoroutineStub(channel)
95+
96+
runBlocking {
97+
assertThat(client.sayHello(helloRequest("")).message).isEqualTo("interceptor")
98+
}
99+
}
100+
}

0 commit comments

Comments
 (0)