Skip to content

Commit e22acf6

Browse files
committed
OperatorSwitch - fix lost requests race condition using ProducerArbiter
1 parent 9549363 commit e22acf6

File tree

3 files changed

+63
-107
lines changed

3 files changed

+63
-107
lines changed

src/main/java/rx/internal/operators/OperatorSwitch.java

Lines changed: 53 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import rx.Observable.Operator;
2323
import rx.Producer;
2424
import rx.Subscriber;
25+
import rx.internal.producers.ProducerArbiter;
2526
import rx.observers.SerializedSubscriber;
2627
import rx.subscriptions.SerialSubscription;
2728

@@ -46,7 +47,9 @@ private static final class Holder {
4647
public static <T> OperatorSwitch<T> instance() {
4748
return (OperatorSwitch<T>)Holder.INSTANCE;
4849
}
50+
4951
private OperatorSwitch() { }
52+
5053
@Override
5154
public Subscriber<? super Observable<? extends T>> call(final Subscriber<? super T> child) {
5255
SwitchSubscriber<T> sws = new SwitchSubscriber<T>(child);
@@ -55,10 +58,12 @@ public Subscriber<? super Observable<? extends T>> call(final Subscriber<? super
5558
}
5659

5760
private static final class SwitchSubscriber<T> extends Subscriber<Observable<? extends T>> {
58-
final SerializedSubscriber<T> s;
61+
final SerializedSubscriber<T> serializedChild;
5962
final SerialSubscription ssub;
6063
final Object guard = new Object();
6164
final NotificationLite<?> nl = NotificationLite.instance();
65+
final ProducerArbiter arbiter;
66+
6267
/** Guarded by guard. */
6368
int index;
6469
/** Guarded by guard. */
@@ -70,50 +75,19 @@ private static final class SwitchSubscriber<T> extends Subscriber<Observable<? e
7075
/** Guarded by guard. */
7176
boolean emitting;
7277
/** Guarded by guard. */
73-
InnerSubscriber currentSubscriber;
74-
/** Guarded by guard. */
75-
long initialRequested;
76-
77-
volatile boolean infinite = false;
78+
InnerSubscriber<T> currentSubscriber;
7879

79-
public SwitchSubscriber(Subscriber<? super T> child) {
80-
s = new SerializedSubscriber<T>(child);
80+
SwitchSubscriber(Subscriber<? super T> child) {
81+
serializedChild = new SerializedSubscriber<T>(child);
82+
arbiter = new ProducerArbiter();
8183
ssub = new SerialSubscription();
8284
child.add(ssub);
8385
child.setProducer(new Producer(){
8486

8587
@Override
8688
public void request(long n) {
87-
if (infinite) {
88-
return;
89-
}
90-
if(n == Long.MAX_VALUE) {
91-
infinite = true;
92-
}
93-
InnerSubscriber localSubscriber;
94-
synchronized (guard) {
95-
localSubscriber = currentSubscriber;
96-
if (currentSubscriber == null) {
97-
long r = initialRequested + n;
98-
if (r < 0) {
99-
infinite = true;
100-
} else {
101-
initialRequested = r;
102-
}
103-
} else {
104-
long r = currentSubscriber.requested + n;
105-
if (r < 0) {
106-
infinite = true;
107-
} else {
108-
currentSubscriber.requested = r;
109-
}
110-
}
111-
}
112-
if (localSubscriber != null) {
113-
if (infinite)
114-
localSubscriber.requestMore(Long.MAX_VALUE);
115-
else
116-
localSubscriber.requestMore(n);
89+
if (n > 0) {
90+
arbiter.request(n);
11791
}
11892
}
11993
});
@@ -122,26 +96,18 @@ public void request(long n) {
12296
@Override
12397
public void onNext(Observable<? extends T> t) {
12498
final int id;
125-
long remainingRequest;
12699
synchronized (guard) {
127100
id = ++index;
128101
active = true;
129-
if (infinite) {
130-
remainingRequest = Long.MAX_VALUE;
131-
} else {
132-
remainingRequest = currentSubscriber == null ? initialRequested : currentSubscriber.requested;
133-
}
134-
currentSubscriber = new InnerSubscriber(id, remainingRequest);
135-
currentSubscriber.requested = remainingRequest;
102+
currentSubscriber = new InnerSubscriber<T>(id, arbiter, this);
136103
}
137104
ssub.set(currentSubscriber);
138-
139105
t.unsafeSubscribe(currentSubscriber);
140106
}
141107

142108
@Override
143109
public void onError(Throwable e) {
144-
s.onError(e);
110+
serializedChild.onError(e);
145111
unsubscribe();
146112
}
147113

@@ -165,10 +131,10 @@ public void onCompleted() {
165131
emitting = true;
166132
}
167133
drain(localQueue);
168-
s.onCompleted();
134+
serializedChild.onCompleted();
169135
unsubscribe();
170136
}
171-
void emit(T value, int id, InnerSubscriber innerSubscriber) {
137+
void emit(T value, int id, InnerSubscriber<T> innerSubscriber) {
172138
List<Object> localQueue;
173139
synchronized (guard) {
174140
if (id != index) {
@@ -178,8 +144,6 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) {
178144
if (queue == null) {
179145
queue = new ArrayList<Object>();
180146
}
181-
if (innerSubscriber.requested != Long.MAX_VALUE)
182-
innerSubscriber.requested--;
183147
queue.add(value);
184148
return;
185149
}
@@ -194,11 +158,8 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) {
194158
drain(localQueue);
195159
if (once) {
196160
once = false;
197-
synchronized (guard) {
198-
if (innerSubscriber.requested != Long.MAX_VALUE)
199-
innerSubscriber.requested--;
200-
}
201-
s.onNext(value);
161+
serializedChild.onNext(value);
162+
arbiter.produced(1);
202163
}
203164
synchronized (guard) {
204165
localQueue = queue;
@@ -209,7 +170,7 @@ void emit(T value, int id, InnerSubscriber innerSubscriber) {
209170
break;
210171
}
211172
}
212-
} while (!s.isUnsubscribed());
173+
} while (!serializedChild.isUnsubscribed());
213174
} finally {
214175
if (!skipFinal) {
215176
synchronized (guard) {
@@ -224,16 +185,17 @@ void drain(List<Object> localQueue) {
224185
}
225186
for (Object o : localQueue) {
226187
if (nl.isCompleted(o)) {
227-
s.onCompleted();
188+
serializedChild.onCompleted();
228189
break;
229190
} else
230191
if (nl.isError(o)) {
231-
s.onError(nl.getError(o));
192+
serializedChild.onError(nl.getError(o));
232193
break;
233194
} else {
234195
@SuppressWarnings("unchecked")
235196
T t = (T)o;
236-
s.onNext(t);
197+
serializedChild.onNext(t);
198+
arbiter.produced(1);
237199
}
238200
}
239201
}
@@ -258,7 +220,7 @@ void error(Throwable e, int id) {
258220
}
259221

260222
drain(localQueue);
261-
s.onError(e);
223+
serializedChild.onError(e);
262224
unsubscribe();
263225
}
264226
void complete(int id) {
@@ -285,51 +247,45 @@ void complete(int id) {
285247
}
286248

287249
drain(localQueue);
288-
s.onCompleted();
250+
serializedChild.onCompleted();
289251
unsubscribe();
290252
}
291253

292-
final class InnerSubscriber extends Subscriber<T> {
293-
294-
/**
295-
* The number of request that is not acknowledged.
296-
*
297-
* Guarded by guard.
298-
*/
299-
private long requested = 0;
300-
301-
private final int id;
254+
}
255+
256+
private static final class InnerSubscriber<T> extends Subscriber<T> {
302257

303-
private final long initialRequested;
258+
private final int id;
304259

305-
public InnerSubscriber(int id, long initialRequested) {
306-
this.id = id;
307-
this.initialRequested = initialRequested;
308-
}
260+
private final ProducerArbiter arbiter;
309261

310-
@Override
311-
public void onStart() {
312-
requestMore(initialRequested);
313-
}
262+
private final SwitchSubscriber<T> parent;
314263

315-
public void requestMore(long n) {
316-
request(n);
317-
}
264+
InnerSubscriber(int id, ProducerArbiter arbiter, SwitchSubscriber<T> parent) {
265+
this.id = id;
266+
this.arbiter = arbiter;
267+
this.parent = parent;
268+
}
269+
270+
@Override
271+
public void setProducer(Producer p) {
272+
arbiter.setProducer(p);
273+
}
318274

319-
@Override
320-
public void onNext(T t) {
321-
emit(t, id, this);
322-
}
275+
@Override
276+
public void onNext(T t) {
277+
parent.emit(t, id, this);
278+
}
323279

324-
@Override
325-
public void onError(Throwable e) {
326-
error(e, id);
327-
}
280+
@Override
281+
public void onError(Throwable e) {
282+
parent.error(e, id);
283+
}
328284

329-
@Override
330-
public void onCompleted() {
331-
complete(id);
332-
}
285+
@Override
286+
public void onCompleted() {
287+
parent.complete(id);
333288
}
334289
}
290+
335291
}

src/test/java/rx/internal/operators/OperatorSwitchIfEmptyTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import rx.Observable;
2828
import rx.Observable.OnSubscribe;
2929
import rx.functions.Action0;
30-
import rx.functions.Action1;
3130
import rx.observers.TestSubscriber;
3231
import rx.schedulers.Schedulers;
3332
import rx.subscriptions.Subscriptions;

src/test/java/rx/internal/operators/OperatorSwitchTest.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import static org.mockito.Mockito.times;
2626
import static org.mockito.Mockito.verify;
2727

28-
import java.util.ArrayList;
2928
import java.util.Arrays;
3029
import java.util.List;
3130
import java.util.concurrent.CopyOnWriteArrayList;
@@ -642,32 +641,34 @@ public Observable<Long> call(Long t) {
642641
}
643642

644643
@Test(timeout = 10000)
645-
public void testSecondaryRequestsAdditivelyAreMoreThanLongMaxValueInducesMaxValueRequestFromUpstream() throws InterruptedException {
644+
public void testSecondaryRequestsAdditivelyAreMoreThanLongMaxValueInducesMaxValueRequestFromUpstream()
645+
throws InterruptedException {
646646
final List<Long> requests = new CopyOnWriteArrayList<Long>();
647647
final Action1<Long> addRequest = new Action1<Long>() {
648648

649649
@Override
650650
public void call(Long n) {
651651
requests.add(n);
652-
}};
653-
TestSubscriber<Long> ts = new TestSubscriber<Long>(0);
652+
}
653+
};
654+
TestSubscriber<Long> ts = new TestSubscriber<Long>(1);
654655
Observable.switchOnNext(
655656
Observable.interval(100, TimeUnit.MILLISECONDS)
656657
.map(new Func1<Long, Observable<Long>>() {
657658
@Override
658659
public Observable<Long> call(Long t) {
659-
return Observable.from(Arrays.asList(1L, 2L, 3L)).doOnRequest(addRequest);
660+
return Observable.from(Arrays.asList(1L, 2L, 3L)).doOnRequest(
661+
addRequest);
660662
}
661663
}).take(3)).subscribe(ts);
662-
ts.requestMore(1);
663-
//we will miss two of the first observable
664+
// we will miss two of the first observables
664665
Thread.sleep(250);
665666
ts.requestMore(Long.MAX_VALUE - 1);
666667
ts.requestMore(Long.MAX_VALUE - 1);
667668
ts.awaitTerminalEvent();
668669
assertTrue(ts.getOnNextEvents().size() > 0);
669670
assertEquals(5, (int) requests.size());
670-
assertEquals(Long.MAX_VALUE, (long) requests.get(3));
671-
assertEquals(Long.MAX_VALUE, (long) requests.get(4));
671+
assertEquals(Long.MAX_VALUE, (long) requests.get(requests.size()-1));
672672
}
673+
673674
}

0 commit comments

Comments
 (0)