Skip to content

Support Synchronous Source in OnSubscribeRefCount #1753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 93 additions & 165 deletions src/main/java/rx/internal/operators/OnSubscribeRefCount.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,195 +15,123 @@
*/
package rx.internal.operators;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

import rx.Observable.OnSubscribe;
import rx.Subscriber;
import rx.Subscription;
import rx.functions.Action0;
import rx.functions.Action1;
import rx.observables.ConnectableObservable;
import rx.subscriptions.CompositeSubscription;
import rx.subscriptions.Subscriptions;

/**
* Returns an observable sequence that stays connected to the source as long
* as there is at least one subscription to the observable sequence.
* @param <T> the value type
* Returns an observable sequence that stays connected to the source as long as
* there is at least one subscription to the observable sequence.
*
* @param <T>
* the value type
*/
public final class OnSubscribeRefCount<T> implements OnSubscribe<T> {
final ConnectableObservable<? extends T> source;
final Object guard;
/** Guarded by guard. */
int index;
/** Guarded by guard. */
boolean emitting;
/** Guarded by guard. If true, indicates a connection request, false indicates a disconnect request. */
List<Token> queue;
/** Manipulated while in the serialized section. */
int count;
/** Manipulated while in the serialized section. */
Subscription connection;
/** Manipulated while in the serialized section. */
final Map<Token, Object> connectionStatus;
/** Occupied indicator. */
private static final Object OCCUPIED = new Object();

private final ConnectableObservable<? extends T> source;
private volatile CompositeSubscription baseSubscription = new CompositeSubscription();
private final AtomicInteger subscriptionCount = new AtomicInteger(0);

/**
* Use this lock for every subscription and disconnect action.
*/
private final ReentrantLock lock = new ReentrantLock();

/**
* Constructor.
*
* @param source
* observable to apply ref count to
*/
public OnSubscribeRefCount(ConnectableObservable<? extends T> source) {
this.source = source;
this.guard = new Object();
this.connectionStatus = new WeakHashMap<Token, Object>();
}

@Override
public void call(Subscriber<? super T> t1) {
int id;
synchronized (guard) {
id = ++index;
}
final Token t = new Token(id);
t1.add(Subscriptions.create(new Action0() {
@Override
public void call() {
disconnect(t);
}
}));
source.unsafeSubscribe(t1);
connect(t);
}
private void connect(Token id) {
List<Token> localQueue;
synchronized (guard) {
if (emitting) {
if (queue == null) {
queue = new ArrayList<Token>();
}
queue.add(id);
return;
}

localQueue = queue;
queue = null;
emitting = true;
}
boolean once = true;
do {
drain(localQueue);
if (once) {
once = false;
doConnect(id);
}
synchronized (guard) {
localQueue = queue;
queue = null;
if (localQueue == null) {
emitting = false;
return;
}
}
} while (true);
}
private void disconnect(Token id) {
List<Token> localQueue;
synchronized (guard) {
if (emitting) {
if (queue == null) {
queue = new ArrayList<Token>();
}
queue.add(id.toDisconnect()); // negative value indicates disconnect
return;
}

localQueue = queue;
queue = null;
emitting = true;
}
boolean once = true;
do {
drain(localQueue);
if (once) {
once = false;
doDisconnect(id);
}
synchronized (guard) {
localQueue = queue;
queue = null;
if (localQueue == null) {
emitting = false;
return;
public void call(final Subscriber<? super T> subscriber) {

lock.lock();
if (subscriptionCount.incrementAndGet() == 1) {

final AtomicBoolean writeLocked = new AtomicBoolean(true);

try {
// need to use this overload of connect to ensure that
// baseSubscription is set in the case that source is a
// synchronous Observable
source.connect(onSubscribe(subscriber, writeLocked));
} finally {
// need to cover the case where the source is subscribed to
// outside of this class thus preventing the above Action1
// being called
if (writeLocked.get()) {
// Action1 was not called
lock.unlock();
}
}
} while (true);
}
private void drain(List<Token> localQueue) {
if (localQueue == null) {
return;
}
int n = localQueue.size();
for (int i = 0; i < n; i++) {
Token id = localQueue.get(i);
if (id.isDisconnect()) {
doDisconnect(id);
} else {
doConnect(id);
}
}
}
private void doConnect(Token id) {
// this method is called only once per id
// if add succeeds, id was not yet disconnected
if (connectionStatus.put(id, OCCUPIED) == null) {
if (count++ == 0) {
connection = source.connect();
}
} else {
// connection exists due to disconnect, just remove
connectionStatus.remove(id);
}
}
private void doDisconnect(Token id) {
// this method is called only once per id
// if remove succeeds, id was connected
if (connectionStatus.remove(id) != null) {
if (--count == 0) {
connection.unsubscribe();
connection = null;
try {
// handle unsubscribing from the base subscription
subscriber.add(disconnect());

// ready to subscribe to source so do it
source.unsafeSubscribe(subscriber);
} finally {
// release the read lock
lock.unlock();
}
} else {
// mark id as if connected
connectionStatus.put(id, OCCUPIED);
}

}
/** Token that represens a connection request or a disconnection request. */
private static final class Token {
final int id;
public Token(int id) {
this.id = id;
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (obj.getClass() != getClass()) {
return false;
private Action1<Subscription> onSubscribe(final Subscriber<? super T> subscriber,
final AtomicBoolean writeLocked) {
return new Action1<Subscription>() {
@Override
public void call(Subscription subscription) {

try {
baseSubscription.add(subscription);

// handle unsubscribing from the base subscription
subscriber.add(disconnect());

// ready to subscribe to source so do it
source.unsafeSubscribe(subscriber);
} finally {
// release the write lock
lock.unlock();
writeLocked.set(false);
}
}
int other = ((Token)obj).id;
return id == other || -id == other;
}
};
}

@Override
public int hashCode() {
return id < 0 ? -id : id;
}
public boolean isDisconnect() {
return id < 0;
}
public Token toDisconnect() {
if (id < 0) {
return this;
private Subscription disconnect() {
return Subscriptions.create(new Action0() {
@Override
public void call() {
lock.lock();
try {
if (subscriptionCount.decrementAndGet() == 0) {
baseSubscription.unsubscribe();
// need a new baseSubscription because once
// unsubscribed stays that way
baseSubscription = new CompositeSubscription();
}
} finally {
lock.unlock();
}
}
return new Token(-id);
}
});
}
}
8 changes: 7 additions & 1 deletion src/main/java/rx/internal/operators/OperatorMulticast.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,13 @@ public void call() {
}));

// now that everything is hooked up let's subscribe
source.unsafeSubscribe(subscription);
// as long as the subscription is not null
boolean subscriptionIsNull;
synchronized(guard) {
subscriptionIsNull = subscription == null;
}
if (!subscriptionIsNull)
source.unsafeSubscribe(subscription);
}
}
}
50 changes: 50 additions & 0 deletions src/test/java/rx/RefCountTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package rx;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
Expand All @@ -25,6 +26,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

Expand All @@ -34,11 +36,14 @@
import org.mockito.MockitoAnnotations;

import rx.Observable.OnSubscribe;
import rx.Observable.Operator;
import rx.functions.Action0;
import rx.functions.Action1;
import rx.functions.Func2;
import rx.observables.ConnectableObservable;
import rx.observers.Subscribers;
import rx.observers.TestSubscriber;
import rx.schedulers.Schedulers;
import rx.schedulers.TestScheduler;
import rx.subjects.ReplaySubject;
import rx.subscriptions.Subscriptions;
Expand Down Expand Up @@ -237,4 +242,49 @@ public Integer call(Integer t1, Integer t2) {
ts2.assertNoErrors();
ts2.assertReceivedOnNext(Arrays.asList(30));
}

@Test
public void testRefCountUnsubscribeForSynchronousSource() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
Observable<Long> o = synchronousInterval().lift(detectUnsubscription(latch));
Subscriber<Long> sub = Subscribers.empty();
o.publish().refCount().subscribeOn(Schedulers.computation()).subscribe(sub);
Thread.sleep(100);
sub.unsubscribe();
assertTrue(latch.await(3, TimeUnit.SECONDS));
}

@Test
public void testSubscribeToPublishWithAlreadyUnsubscribedSubscriber() {
Subscriber<Object> sub = Subscribers.empty();
sub.unsubscribe();
ConnectableObservable<Object> o = Observable.empty().publish();
o.subscribe(sub);
o.connect();
}

private Operator<Long, Long> detectUnsubscription(final CountDownLatch latch) {
return new Operator<Long,Long>(){
@Override
public Subscriber<? super Long> call(Subscriber<? super Long> subscriber) {
latch.countDown();
return Subscribers.from(subscriber);
}};
}

private Observable<Long> synchronousInterval() {
return Observable.create(new OnSubscribe<Long>() {

@Override
public void call(Subscriber<? super Long> subscriber) {
while (!subscriber.isUnsubscribed()) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
subscriber.onNext(1L);
}
}});
}

}