27
27
import org .reactivestreams .Subscriber ;
28
28
import org .reactivestreams .Subscription ;
29
29
import reactor .core .Disposable ;
30
- import reactor .core .publisher .*;
30
+ import reactor .core .publisher .Flux ;
31
+ import reactor .core .publisher .Mono ;
32
+ import reactor .core .publisher .SignalType ;
33
+ import reactor .core .publisher .UnicastProcessor ;
31
34
32
35
import java .util .Collections ;
33
36
import java .util .Map ;
39
42
import static io .rsocket .frame .FrameHeaderFlyweight .FLAGS_M ;
40
43
41
44
/** Server side RSocket. Receives {@link Frame}s from a {@link RSocketClient} */
42
- class RSocketServer implements RSocket {
45
+ class RSocketServer implements RequestHandler {
43
46
44
47
private final DuplexConnection connection ;
45
48
private final RSocket requestHandler ;
49
+ private final RequestHandler optimizedRequestHandler ;
50
+ private final boolean hasOptimizedRequestHandler ;
46
51
private final Function <Frame , ? extends Payload > frameDecoder ;
47
52
private final Consumer <Throwable > errorConsumer ;
48
53
49
54
private final Map <Integer , Subscription > sendingSubscriptions ;
50
- private final Map <Integer , Processor <Payload ,Payload >> channelProcessors ;
55
+ private final Map <Integer , Processor <Payload , Payload >> channelProcessors ;
51
56
52
57
private final UnboundedProcessor <Frame > sendProcessor ;
53
58
private KeepAliveHandler keepAliveHandler ;
@@ -69,12 +74,22 @@ class RSocketServer implements RSocket {
69
74
Consumer <Throwable > errorConsumer ,
70
75
long tickPeriod ,
71
76
long ackTimeout ) {
77
+
78
+ if (requestHandler instanceof RequestHandler ) {
79
+ this .optimizedRequestHandler = (RequestHandler ) requestHandler ;
80
+ this .hasOptimizedRequestHandler = true ;
81
+ this .requestHandler = null ;
82
+ } else {
83
+ this .hasOptimizedRequestHandler = false ;
84
+ this .requestHandler = requestHandler ;
85
+ this .optimizedRequestHandler = null ;
86
+ }
87
+
72
88
this .connection = connection ;
73
- this .requestHandler = requestHandler ;
74
89
this .frameDecoder = frameDecoder ;
75
90
this .errorConsumer = errorConsumer ;
76
91
this .sendingSubscriptions = Collections .synchronizedMap (new IntObjectHashMap <>());
77
- this .channelProcessors = Collections .synchronizedMap (new IntObjectHashMap <>());
92
+ this .channelProcessors = Collections .synchronizedMap (new IntObjectHashMap <>());
78
93
79
94
// DO NOT Change the order here. The Send processor must be subscribed to before receiving
80
95
// connections
@@ -116,43 +131,55 @@ class RSocketServer implements RSocket {
116
131
}
117
132
118
133
private void handleSendProcessorError (Throwable t ) {
119
- sendingSubscriptions .values ().forEach (subscription -> {
120
- try {
121
- subscription .cancel ();
122
- } catch (Throwable e ) {
123
- errorConsumer .accept (e );
124
- }
125
- });
134
+ sendingSubscriptions
135
+ .values ()
136
+ .forEach (
137
+ subscription -> {
138
+ try {
139
+ subscription .cancel ();
140
+ } catch (Throwable e ) {
141
+ errorConsumer .accept (e );
142
+ }
143
+ });
126
144
127
- channelProcessors .values ().forEach (subscription -> {
128
- try {
129
- subscription .onError (t );
130
- } catch (Throwable e ) {
131
- errorConsumer .accept (e );
132
- }
133
- });
145
+ channelProcessors
146
+ .values ()
147
+ .forEach (
148
+ subscription -> {
149
+ try {
150
+ subscription .onError (t );
151
+ } catch (Throwable e ) {
152
+ errorConsumer .accept (e );
153
+ }
154
+ });
134
155
}
135
156
136
157
private void handleSendProcessorCancel (SignalType t ) {
137
158
if (SignalType .ON_ERROR == t ) {
138
159
return ;
139
160
}
140
161
141
- sendingSubscriptions .values ().forEach (subscription -> {
142
- try {
143
- subscription .cancel ();
144
- } catch (Throwable e ) {
145
- errorConsumer .accept (e );
146
- }
147
- });
162
+ sendingSubscriptions
163
+ .values ()
164
+ .forEach (
165
+ subscription -> {
166
+ try {
167
+ subscription .cancel ();
168
+ } catch (Throwable e ) {
169
+ errorConsumer .accept (e );
170
+ }
171
+ });
148
172
149
- channelProcessors .values ().forEach (subscription -> {
150
- try {
151
- subscription .onComplete ();
152
- } catch (Throwable e ) {
153
- errorConsumer .accept (e );
154
- }
155
- });
173
+ channelProcessors
174
+ .values ()
175
+ .forEach (
176
+ subscription -> {
177
+ try {
178
+ subscription .onComplete ();
179
+ } catch (Throwable e ) {
180
+ errorConsumer .accept (e );
181
+ }
182
+ });
156
183
}
157
184
158
185
@ Override
@@ -191,6 +218,15 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
191
218
}
192
219
}
193
220
221
+ @ Override
222
+ public Flux <Payload > requestChannel (Payload payload , Publisher <Payload > payloads ) {
223
+ try {
224
+ return optimizedRequestHandler .requestChannel (payloads );
225
+ } catch (Throwable t ) {
226
+ return Flux .error (t );
227
+ }
228
+ }
229
+
194
230
@ Override
195
231
public Mono <Void > metadataPush (Payload payload ) {
196
232
try {
@@ -232,9 +268,7 @@ private synchronized void cleanUpSendingSubscriptions() {
232
268
}
233
269
234
270
private synchronized void cleanUpChannelProcessors () {
235
- channelProcessors
236
- .values ()
237
- .forEach (Processor ::onComplete );
271
+ channelProcessors .values ().forEach (Processor ::onComplete );
238
272
channelProcessors .clear ();
239
273
}
240
274
@@ -381,7 +415,11 @@ private void handleChannel(int streamId, Payload payload, int initialRequestN) {
381
415
// and any later payload can be processed
382
416
frames .onNext (payload );
383
417
384
- handleStream (streamId , requestChannel (payloads ), initialRequestN );
418
+ if (hasOptimizedRequestHandler ) {
419
+ handleStream (streamId , requestChannel (payload , payloads ), initialRequestN );
420
+ } else {
421
+ handleStream (streamId , requestChannel (payloads ), initialRequestN );
422
+ }
385
423
}
386
424
387
425
private void handleKeepAliveFrame (Frame frame ) {
0 commit comments