Skip to content

Commit 989f1f1

Browse files
committed
Add a ServerInterceptor API for injecting custom CoroutineContext elements.
PiperOrigin-RevId: 313664182
1 parent f0e2628 commit 989f1f1

File tree

7 files changed

+176
-17
lines changed

7 files changed

+176
-17
lines changed

compiler/src/main/java/io/grpc/kotlin/generator/protoc/AbstractGeneratorRunner.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package io.grpc.kotlin.generator.protoc
1919
import com.google.common.annotations.VisibleForTesting
2020
import com.google.protobuf.DescriptorProtos.FileDescriptorSet
2121
import com.google.protobuf.Descriptors.FileDescriptor
22-
import com.google.protobuf.compiler.PluginProtos
22+
import com.google.protobuf.compiler.PluginProtosProtos
2323
import com.squareup.kotlinpoet.FileSpec
2424
import java.io.IOException
2525
import java.io.InputStream
@@ -34,7 +34,7 @@ abstract class AbstractGeneratorRunner {
3434
abstract fun generateCodeForFile(file: FileDescriptor): List<FileSpec>
3535

3636
@VisibleForTesting
37-
fun mainAsProtocPlugin(input: InputStream, output: OutputStream) {
37+
fun mainAsProtocPluginProtos(input: InputStream, output: OutputStream) {
3838
val generatorRequest = try {
3939
input.buffered().use {
4040
PluginProtos.CodeGeneratorRequest.parseFrom(it)
@@ -84,7 +84,7 @@ abstract class AbstractGeneratorRunner {
8484

8585
fun doMain(args: Array<String>) {
8686
if (args.isEmpty()) {
87-
mainAsProtocPlugin(System.`in`, System.out)
87+
mainAsProtocPluginProtos(System.`in`, System.out)
8888
} else {
8989
mainAsCommandLine(args, FileSystems.getDefault())
9090
}

compiler/src/main/java/io/grpc/kotlin/generator/protoc/CodeGenerators.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import com.google.common.base.Throwables
2020
import com.google.common.graph.GraphBuilder
2121
import com.google.protobuf.DescriptorProtos.FileDescriptorProto
2222
import com.google.protobuf.Descriptors.FileDescriptor
23-
import com.google.protobuf.compiler.PluginProtos
23+
import com.google.protobuf.compiler.PluginProtosProtos
2424
import com.squareup.kotlinpoet.FileSpec
2525
import io.grpc.kotlin.generator.protoc.util.graph.TopologicalSortGraph
2626

compiler/src/main/java/io/grpc/kotlin/generator/protoc/ProtoFileName.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package io.grpc.kotlin.generator.protoc
1818

1919
import com.google.protobuf.DescriptorProtos.FileDescriptorProto
2020
import com.google.protobuf.Descriptors.FileDescriptor
21-
import com.google.protobuf.compiler.PluginProtos
21+
import com.google.protobuf.compiler.PluginProtosProtos
2222

2323
/**
2424
* Represents the name of a proto file, relative to the root of the source tree, with
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package io.grpc.kotlin
2+
3+
import io.grpc.Metadata
4+
import io.grpc.ServerCall
5+
import io.grpc.ServerCallHandler
6+
import io.grpc.ServerInterceptor
7+
import kotlin.coroutines.CoroutineContext
8+
import kotlin.coroutines.EmptyCoroutineContext
9+
import io.grpc.Context as GrpcContext
10+
11+
/**
12+
* A [ServerInterceptor] subtype that can install elements in the [CoroutineContext] where server
13+
* logic is executed. These elements are applied "after" the
14+
* [AbstractCoroutineServerImpl.context]; that is, the interceptor overrides the server's context.
15+
*/
16+
abstract class CoroutineContextServerInterceptor : ServerInterceptor {
17+
companion object {
18+
// This is deliberately kept visibility-restricted; it's intentional that the only way to affect
19+
// the CoroutineContext is to extend CoroutineContextServerInterceptor.
20+
internal val COROUTINE_CONTEXT_KEY : GrpcContext.Key<CoroutineContext> =
21+
GrpcContext.keyWithDefault("grpc-kotlin-coroutine-context", EmptyCoroutineContext)
22+
23+
private fun GrpcContext.extendCoroutineContext(coroutineContext: CoroutineContext): GrpcContext {
24+
val oldCoroutineContext: CoroutineContext = COROUTINE_CONTEXT_KEY[this]
25+
val newCoroutineContext = oldCoroutineContext + coroutineContext
26+
return withValue(COROUTINE_CONTEXT_KEY, newCoroutineContext)
27+
}
28+
}
29+
30+
/**
31+
* Override this function to return a [CoroutineContext] in which to execute [call] and [headers].
32+
* The returned [CoroutineContext] will override any corresponding context elements in the
33+
* server object.
34+
*
35+
* This function will be called each time a [call] is executed.
36+
*/
37+
abstract fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext
38+
39+
private inline fun <R> withGrpcContext(context: GrpcContext, action: () -> R): R {
40+
val oldContext: GrpcContext = context.attach()
41+
return try {
42+
action()
43+
} finally {
44+
context.detach(oldContext)
45+
}
46+
}
47+
48+
final override fun <ReqT, RespT> interceptCall(
49+
call: ServerCall<ReqT, RespT>,
50+
headers: Metadata,
51+
next: ServerCallHandler<ReqT, RespT>
52+
): ServerCall.Listener<ReqT> =
53+
withGrpcContext(GrpcContext.current().extendCoroutineContext(coroutineContext(call, headers))) {
54+
next.startCall(call, headers)
55+
}
56+
}

stub/src/main/java/io/grpc/kotlin/ServerCalls.kt

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import io.grpc.ServerMethodDefinition
2727
import io.grpc.Status
2828
import io.grpc.StatusException
2929
import kotlinx.coroutines.CancellationException
30-
import kotlinx.coroutines.CoroutineName
3130
import kotlinx.coroutines.CoroutineScope
3231
import kotlinx.coroutines.async
3332
import kotlinx.coroutines.cancel
@@ -194,9 +193,11 @@ object ServerCalls {
194193
context: CoroutineContext,
195194
implementation: (Flow<RequestT>) -> Flow<ResponseT>
196195
): ServerCallHandler<RequestT, ResponseT> =
197-
ServerCallHandler {
198-
call, _ -> serverCallListener(
199-
context + GrpcContextElement.current(),
196+
ServerCallHandler { call, _ ->
197+
serverCallListener(
198+
context
199+
+ CoroutineContextServerInterceptor.COROUTINE_CONTEXT_KEY.get()
200+
+ GrpcContextElement.current(),
200201
call,
201202
implementation
202203
)
@@ -233,11 +234,9 @@ object ServerCalls {
233234
throw e
234235
}
235236
}
236-
237+
237238
val rpcScope = CoroutineScope(context)
238-
val rpcJob = rpcScope.async(
239-
CoroutineName("${call.methodDescriptor.fullMethodName} implementation")
240-
) {
239+
val rpcJob = rpcScope.async {
241240
runCatching {
242241
implementation(requests).collect {
243242
readiness.suspendUntilReady()
@@ -257,7 +256,7 @@ object ServerCalls {
257256
call.close(closeStatus, trailers)
258257
}
259258

260-
return object: ServerCall.Listener<RequestT>() {
259+
return object : ServerCall.Listener<RequestT>() {
261260
var isReceiving = true
262261

263262
override fun onCancel() {

stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,10 @@ abstract class AbstractCallsTest {
149149
}
150150

151151
/** Generates a channel to a Greeter server with the specified implementation. */
152-
fun makeChannel(impl: BindableService): ManagedChannel {
152+
fun makeChannel(impl: BindableService, vararg interceptors: ServerInterceptor): ManagedChannel =
153+
makeChannel(ServerInterceptors.intercept(impl, *interceptors))
154+
155+
fun makeChannel(impl: ServerServiceDefinition): ManagedChannel {
153156
val serverName = InProcessServerBuilder.generateName()
154157

155158
grpcCleanup.register(
@@ -184,8 +187,9 @@ abstract class AbstractCallsTest {
184187
builder.addMethod(method, ServerCallHandler { _, _ -> TODO() })
185188
}
186189
}
187-
ServerInterceptors.intercept(builder.build(), *interceptors)
188-
}
190+
builder.build()
191+
},
192+
*interceptors
189193
)
190194

191195
fun <R> runBlocking(block: suspend CoroutineScope.() -> R): Unit =
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package io.grpc.kotlin
2+
3+
import com.google.common.truth.Truth.assertThat
4+
import io.grpc.ServerCall
5+
import io.grpc.ServerInterceptors
6+
import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineImplBase
7+
import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub
8+
import io.grpc.examples.helloworld.HelloReply
9+
import io.grpc.examples.helloworld.HelloRequest
10+
import org.junit.Test
11+
import org.junit.runner.RunWith
12+
import org.junit.runners.JUnit4
13+
import kotlin.coroutines.CoroutineContext
14+
import kotlin.coroutines.EmptyCoroutineContext
15+
import kotlin.coroutines.coroutineContext
16+
import io.grpc.Metadata as GrpcMetadata
17+
18+
/** Tests for [CoroutineContextServerInterceptor]. */
19+
@RunWith(JUnit4::class)
20+
class CoroutineContextServerInterceptorTest : AbstractCallsTest() {
21+
class ArbitraryContextElement(val message: String = "") : CoroutineContext.Element {
22+
companion object Key : CoroutineContext.Key<ArbitraryContextElement>
23+
override val key: CoroutineContext.Key<*>
24+
get() = Key
25+
}
26+
27+
class HelloReplyWithContextMessage(
28+
message: String? = null
29+
) : GreeterCoroutineImplBase(
30+
message?.let { ArbitraryContextElement(it) } ?: EmptyCoroutineContext
31+
) {
32+
override suspend fun sayHello(request: HelloRequest): HelloReply =
33+
helloReply(coroutineContext[ArbitraryContextElement]!!.message)
34+
}
35+
36+
@Test
37+
fun injectContext() {
38+
val interceptor = object : CoroutineContextServerInterceptor() {
39+
override fun coroutineContext(
40+
call: ServerCall<*, *>,
41+
headers: GrpcMetadata
42+
): CoroutineContext = ArbitraryContextElement("success")
43+
}
44+
45+
val channel = makeChannel(HelloReplyWithContextMessage(), interceptor)
46+
val client = GreeterCoroutineStub(channel)
47+
48+
runBlocking {
49+
assertThat(client.sayHello(helloRequest("")).message).isEqualTo("success")
50+
}
51+
}
52+
53+
@Test
54+
fun conflictingInterceptorsInnermostWins() {
55+
val interceptor1 = object : CoroutineContextServerInterceptor() {
56+
override fun coroutineContext(
57+
call: ServerCall<*, *>,
58+
headers: GrpcMetadata
59+
): CoroutineContext = ArbitraryContextElement("first")
60+
}
61+
val interceptor2 = object : CoroutineContextServerInterceptor() {
62+
override fun coroutineContext(
63+
call: ServerCall<*, *>,
64+
headers: GrpcMetadata
65+
): CoroutineContext = ArbitraryContextElement("second")
66+
}
67+
68+
val channel = makeChannel(
69+
ServerInterceptors.intercept(
70+
ServerInterceptors.intercept(
71+
HelloReplyWithContextMessage(),
72+
interceptor2
73+
),
74+
interceptor1
75+
)
76+
)
77+
val client = GreeterCoroutineStub(channel)
78+
79+
runBlocking {
80+
assertThat(client.sayHello(helloRequest("")).message).isEqualTo("second")
81+
}
82+
}
83+
84+
@Test
85+
fun interceptorContextTakesPriority() {
86+
val interceptor = object : CoroutineContextServerInterceptor() {
87+
override fun coroutineContext(
88+
call: ServerCall<*, *>,
89+
headers: GrpcMetadata
90+
): CoroutineContext = ArbitraryContextElement("interceptor")
91+
}
92+
93+
val channel = makeChannel(HelloReplyWithContextMessage("server"), interceptor)
94+
val client = GreeterCoroutineStub(channel)
95+
96+
runBlocking {
97+
assertThat(client.sayHello(helloRequest("")).message).isEqualTo("interceptor")
98+
}
99+
}
100+
}

0 commit comments

Comments
 (0)