|
16 | 16 |
|
17 | 17 | package io.rsocket.core;
|
18 | 18 |
|
19 |
| -import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; |
20 |
| -import static io.rsocket.frame.FrameHeaderFlyweight.frameType; |
21 |
| -import static io.rsocket.frame.FrameType.CANCEL; |
22 |
| -import static io.rsocket.frame.FrameType.KEEPALIVE; |
23 |
| -import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; |
24 |
| -import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; |
25 |
| -import static io.rsocket.frame.FrameType.REQUEST_STREAM; |
26 |
| -import static org.hamcrest.MatcherAssert.assertThat; |
27 |
| -import static org.hamcrest.Matchers.contains; |
28 |
| -import static org.hamcrest.Matchers.equalTo; |
29 |
| -import static org.hamcrest.Matchers.greaterThanOrEqualTo; |
30 |
| -import static org.hamcrest.Matchers.hasSize; |
31 |
| -import static org.hamcrest.Matchers.instanceOf; |
32 |
| -import static org.hamcrest.Matchers.is; |
33 |
| -import static org.hamcrest.Matchers.not; |
34 |
| -import static org.mockito.ArgumentMatchers.any; |
35 |
| -import static org.mockito.Mockito.verify; |
36 |
| - |
37 | 19 | import io.netty.buffer.ByteBuf;
|
38 | 20 | import io.netty.buffer.ByteBufAllocator;
|
39 | 21 | import io.netty.util.CharsetUtil;
|
| 22 | +import io.netty.util.ReferenceCountUtil; |
| 23 | +import io.netty.util.ReferenceCounted; |
40 | 24 | import io.rsocket.Payload;
|
41 | 25 | import io.rsocket.RSocket;
|
| 26 | +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; |
42 | 27 | import io.rsocket.exceptions.ApplicationErrorException;
|
43 | 28 | import io.rsocket.exceptions.RejectedSetupException;
|
44 | 29 | import io.rsocket.frame.CancelFrameFlyweight;
|
|
50 | 35 | import io.rsocket.frame.RequestChannelFrameFlyweight;
|
51 | 36 | import io.rsocket.frame.RequestNFrameFlyweight;
|
52 | 37 | import io.rsocket.frame.RequestStreamFrameFlyweight;
|
| 38 | +import io.rsocket.frame.decoder.PayloadDecoder; |
| 39 | +import io.rsocket.internal.subscriber.AssertSubscriber; |
53 | 40 | import io.rsocket.lease.RequesterLeaseHandler;
|
54 | 41 | import io.rsocket.test.util.TestSubscriber;
|
| 42 | +import io.rsocket.util.ByteBufPayload; |
55 | 43 | import io.rsocket.util.DefaultPayload;
|
56 | 44 | import io.rsocket.util.EmptyPayload;
|
57 | 45 | import io.rsocket.util.MultiSubscriberRSocket;
|
58 |
| -import java.time.Duration; |
59 |
| -import java.util.ArrayList; |
60 |
| -import java.util.Iterator; |
61 |
| -import java.util.List; |
62 |
| -import java.util.concurrent.ThreadLocalRandom; |
63 |
| -import java.util.function.BiFunction; |
64 |
| -import java.util.stream.Collectors; |
65 |
| -import java.util.stream.Stream; |
66 | 46 | import org.assertj.core.api.Assertions;
|
67 | 47 | import org.junit.Rule;
|
68 | 48 | import org.junit.Test;
|
| 49 | +import org.junit.jupiter.params.provider.Arguments; |
| 50 | +import org.junit.runners.model.Statement; |
69 | 51 | import org.reactivestreams.Publisher;
|
70 | 52 | import org.reactivestreams.Subscriber;
|
71 | 53 | import org.reactivestreams.Subscription;
|
72 | 54 | import reactor.core.publisher.BaseSubscriber;
|
73 | 55 | import reactor.core.publisher.Flux;
|
| 56 | +import reactor.core.publisher.Hooks; |
74 | 57 | import reactor.core.publisher.Mono;
|
75 | 58 | import reactor.core.publisher.MonoProcessor;
|
76 | 59 | import reactor.core.publisher.UnicastProcessor;
|
77 | 60 | import reactor.test.StepVerifier;
|
| 61 | +import reactor.test.util.RaceTestUtils; |
| 62 | + |
| 63 | +import java.time.Duration; |
| 64 | +import java.util.ArrayList; |
| 65 | +import java.util.Iterator; |
| 66 | +import java.util.List; |
| 67 | +import java.util.concurrent.ThreadLocalRandom; |
| 68 | +import java.util.function.BiConsumer; |
| 69 | +import java.util.function.BiFunction; |
| 70 | +import java.util.function.Function; |
| 71 | +import java.util.stream.Collectors; |
| 72 | +import java.util.stream.Stream; |
| 73 | + |
| 74 | +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; |
| 75 | +import static io.rsocket.frame.FrameHeaderFlyweight.frameType; |
| 76 | +import static io.rsocket.frame.FrameType.CANCEL; |
| 77 | +import static io.rsocket.frame.FrameType.KEEPALIVE; |
| 78 | +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; |
| 79 | +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; |
| 80 | +import static io.rsocket.frame.FrameType.REQUEST_STREAM; |
| 81 | +import static org.hamcrest.MatcherAssert.assertThat; |
| 82 | +import static org.hamcrest.Matchers.contains; |
| 83 | +import static org.hamcrest.Matchers.equalTo; |
| 84 | +import static org.hamcrest.Matchers.greaterThanOrEqualTo; |
| 85 | +import static org.hamcrest.Matchers.hasSize; |
| 86 | +import static org.hamcrest.Matchers.instanceOf; |
| 87 | +import static org.hamcrest.Matchers.is; |
| 88 | +import static org.hamcrest.Matchers.not; |
| 89 | +import static org.mockito.ArgumentMatchers.any; |
| 90 | +import static org.mockito.Mockito.verify; |
78 | 91 |
|
79 | 92 | public class RSocketRequesterTest {
|
80 | 93 |
|
@@ -333,6 +346,124 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen
|
333 | 346 | .verify();
|
334 | 347 | }
|
335 | 348 |
|
| 349 | + |
| 350 | + private static Stream<Arguments> racingCases() { |
| 351 | + return Stream.of( |
| 352 | + Arguments.of( |
| 353 | + (Runnable) () -> System.out.println("RequestChannel downstream cancellation case"), |
| 354 | + (Function<ClientSocketRule, Publisher<Payload>>) (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), |
| 355 | + (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>) (as, rule) -> { |
| 356 | + LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); |
| 357 | + ByteBuf metadata = allocator.buffer(); |
| 358 | + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); |
| 359 | + ByteBuf data = allocator.buffer(); |
| 360 | + data.writeCharSequence("def", CharsetUtil.UTF_8); |
| 361 | + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); |
| 362 | + ByteBuf frame = PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, metadata, data); |
| 363 | + |
| 364 | + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); |
| 365 | + } |
| 366 | + )/*,*/ |
| 367 | +// Arguments.of( |
| 368 | +// (Runnable) () -> System.out.println("RequestChannel upstream cancellation 1"), |
| 369 | +// (Function<ClientSocketRule, Publisher<Payload>>) (rule) -> { |
| 370 | +// LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); |
| 371 | +// ByteBuf metadata = allocator.buffer(); |
| 372 | +// metadata.writeCharSequence("abc", CharsetUtil.UTF_8); |
| 373 | +// ByteBuf data = allocator.buffer(); |
| 374 | +// data.writeCharSequence("def", CharsetUtil.UTF_8); |
| 375 | +// return rule.socket.requestChannel(Flux.just(ByteBufPayload.create(data, metadata))); |
| 376 | +// }, |
| 377 | +// (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>) (as, rule) -> { |
| 378 | +// LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); |
| 379 | +// int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); |
| 380 | +// ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); |
| 381 | +// |
| 382 | +// RaceTestUtils.race(() -> as.request(1), () -> rule.connection.addToReceivedBuffer(frame)); |
| 383 | +// } |
| 384 | +// ), |
| 385 | +// Arguments.of( |
| 386 | +// (Runnable) () -> System.out.println("RequestChannel upstream cancellation 2"), |
| 387 | +// (Function<ClientSocketRule, Publisher<Payload>>) (rule) -> { |
| 388 | +// return rule.socket.requestChannel(Flux.just(ByteBufPayload.create("a", "b"), ByteBufPayload.create("c", "d"), ByteBufPayload.create("e", "f"))); |
| 389 | +// }, |
| 390 | +// (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>) (as, rule) -> { |
| 391 | +// LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); |
| 392 | +// int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); |
| 393 | +// ByteBuf frame = CancelFrameFlyweight.encode(allocator, streamId); |
| 394 | +// |
| 395 | +// as.request(1); |
| 396 | +// |
| 397 | +// RaceTestUtils.race(() -> as.request(Long.MAX_VALUE), () -> rule.connection.addToReceivedBuffer(frame)); |
| 398 | +// } |
| 399 | +// ), |
| 400 | +// Arguments.of( |
| 401 | +// (Runnable) () -> System.out.println("RequestResponse downstream cancellation"), |
| 402 | +// (Function<ClientSocketRule, Publisher<Payload>>) (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE), |
| 403 | +// (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>) (as, rule) -> { |
| 404 | +// LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); |
| 405 | +// ByteBuf metadata = allocator.buffer(); |
| 406 | +// metadata.writeCharSequence("abc", CharsetUtil.UTF_8); |
| 407 | +// ByteBuf data = allocator.buffer(); |
| 408 | +// data.writeCharSequence("def", CharsetUtil.UTF_8); |
| 409 | +// int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); |
| 410 | +// ByteBuf frame = PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, metadata, data); |
| 411 | +// |
| 412 | +// RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); |
| 413 | +// } |
| 414 | +// ) |
| 415 | + ); |
| 416 | + } |
| 417 | + |
| 418 | + @Test |
| 419 | + @SuppressWarnings("unchecked") |
| 420 | + public void checkNoLeaksOnRacingTest() { |
| 421 | + |
| 422 | + racingCases() |
| 423 | + .forEach(a -> { |
| 424 | + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); |
| 425 | + LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrumentDefault(); |
| 426 | + ((Runnable)a.get()[0]).run(); |
| 427 | + checkNoLeaksOnRacing(allocator, (Function<ClientSocketRule, Publisher<Payload>>) a.get()[1], (BiConsumer<AssertSubscriber<Payload>, ClientSocketRule>) a.get()[2]); |
| 428 | + |
| 429 | + Hooks.resetOnNextDropped(); |
| 430 | + LeaksTrackingByteBufAllocator.deinstrumentDefault(); |
| 431 | + }); |
| 432 | + } |
| 433 | + |
| 434 | + public void checkNoLeaksOnRacing(LeaksTrackingByteBufAllocator allocator, Function<ClientSocketRule, Publisher<Payload>> initiator, BiConsumer<AssertSubscriber<Payload>, ClientSocketRule> runner) { |
| 435 | + for (int i = 0; i < 100000; i++) { |
| 436 | + System.out.println(i); |
| 437 | + ClientSocketRule clientSocketRule = new ClientSocketRule(); |
| 438 | + try { |
| 439 | + clientSocketRule.apply(new Statement() { |
| 440 | + @Override |
| 441 | + public void evaluate() throws Throwable { |
| 442 | + |
| 443 | + } |
| 444 | + }, null).evaluate(); |
| 445 | + } catch (Throwable throwable) { |
| 446 | + throwable.printStackTrace(); |
| 447 | + } |
| 448 | + |
| 449 | + Publisher<Payload> payloadP = initiator.apply(clientSocketRule); |
| 450 | + AssertSubscriber<Payload> assertSubscriber = AssertSubscriber.create(); |
| 451 | + |
| 452 | + if (payloadP instanceof Flux) { |
| 453 | + ((Flux<Payload>)payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); |
| 454 | + } else { |
| 455 | + ((Mono<Payload>)payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); |
| 456 | + } |
| 457 | + |
| 458 | + runner.accept(assertSubscriber, clientSocketRule); |
| 459 | + |
| 460 | + Assertions.assertThat(clientSocketRule.connection.getSent()) |
| 461 | + .allMatch(ReferenceCounted::release); |
| 462 | + |
| 463 | + allocator.assertHasNoLeaks(); |
| 464 | + } |
| 465 | + } |
| 466 | + |
336 | 467 | static Stream<BiFunction<RSocket, Payload, Publisher<?>>> prepareCalls() {
|
337 | 468 | return Stream.of(
|
338 | 469 | RSocket::fireAndForget,
|
@@ -360,7 +491,7 @@ protected RSocketRequester newRSocket() {
|
360 | 491 | return new RSocketRequester(
|
361 | 492 | ByteBufAllocator.DEFAULT,
|
362 | 493 | connection,
|
363 |
| - DefaultPayload::create, |
| 494 | + PayloadDecoder.ZERO_COPY, |
364 | 495 | throwable -> errors.add(throwable),
|
365 | 496 | StreamIdSupplier.clientSupplier(),
|
366 | 497 | 0,
|
|
0 commit comments