Skip to content

Commit 644806b

Browse files
Fix PublishSubject non-deterministic behavior on concurrent modification
- changed to take snapshot of observers.values() before iterating in onNext/onError/onCompleted so that nested subscriptions that add to observers can't change the values() iteration - single-threaded nested subscriptions are now deterministic - multi-threaded subscriptions will no longer be allowed to race to get into an interating onNext/onError/onCompleted loop, they will always wait until the next - also improved terminal state behavior when subscribing to a PublishSubject that has already received onError/onCompleted ReactiveX#282
1 parent 52640f5 commit 644806b

File tree

1 file changed

+193
-6
lines changed

1 file changed

+193
-6
lines changed

rxjava-core/src/main/java/rx/subjects/PublishSubject.java

Lines changed: 193 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,30 @@
1515
*/
1616
package rx.subjects;
1717

18+
import static org.junit.Assert.*;
1819
import static org.mockito.Matchers.*;
1920
import static org.mockito.Mockito.*;
2021

2122
import java.util.ArrayList;
23+
import java.util.Collection;
2224
import java.util.List;
2325
import java.util.concurrent.ConcurrentHashMap;
2426
import java.util.concurrent.atomic.AtomicBoolean;
27+
import java.util.concurrent.atomic.AtomicInteger;
2528
import java.util.concurrent.atomic.AtomicReference;
2629

2730
import junit.framework.Assert;
2831

2932
import org.junit.Test;
33+
import org.mockito.InOrder;
3034
import org.mockito.Mockito;
3135

3236
import rx.Notification;
3337
import rx.Observable;
3438
import rx.Observer;
3539
import rx.Subscription;
3640
import rx.operators.AtomicObservableSubscription;
41+
import rx.subscriptions.Subscriptions;
3742
import rx.util.functions.Action1;
3843
import rx.util.functions.Func0;
3944
import rx.util.functions.Func1;
@@ -62,10 +67,15 @@
6267
public class PublishSubject<T> extends Subject<T, T> {
6368
public static <T> PublishSubject<T> create() {
6469
final ConcurrentHashMap<Subscription, Observer<T>> observers = new ConcurrentHashMap<Subscription, Observer<T>>();
65-
70+
final AtomicReference<Notification<T>> terminalState = new AtomicReference<Notification<T>>();
71+
6672
Func1<Observer<T>, Subscription> onSubscribe = new Func1<Observer<T>, Subscription>() {
6773
@Override
6874
public Subscription call(Observer<T> observer) {
75+
// first check if terminal state exist
76+
Subscription s = checkTerminalState(observer);
77+
if(s != null) return s;
78+
6979
final AtomicObservableSubscription subscription = new AtomicObservableSubscription();
7080

7181
subscription.wrap(new Subscription() {
@@ -78,41 +88,96 @@ public void unsubscribe() {
7888

7989
// on subscribe add it to the map of outbound observers to notify
8090
observers.put(subscription, observer);
91+
92+
// check terminal state again
93+
s = checkTerminalState(observer);
94+
if(s != null) return s;
95+
96+
/**
97+
* NOTE: There is a race condition here.
98+
*
99+
* 1) terminal state gets set in onError or onCompleted
100+
* 2) observers.put adds a new observer
101+
* 3) checkTerminalState emits onError/onCompleted
102+
* 4) onError or onCompleted also emits onError/onCompleted since it was adds to observers
103+
*
104+
* Thus the terminal state could end up being sent twice.
105+
*
106+
* I'm going to leave this for now as AtomicObserver will protect against this
107+
* and I'd rather not add blocking synchronization in here unless the above race condition
108+
* truly is an issue.
109+
*/
110+
81111
return subscription;
82112
}
113+
114+
private Subscription checkTerminalState(Observer<T> observer) {
115+
Notification<T> n = terminalState.get();
116+
if (n != null) {
117+
// we are terminated to immediately emit and don't continue with subscription
118+
if (n.isOnCompleted()) {
119+
observer.onCompleted();
120+
} else {
121+
observer.onError(n.getException());
122+
}
123+
return Subscriptions.empty();
124+
} else {
125+
return null;
126+
}
127+
}
83128
};
84129

85-
return new PublishSubject<T>(onSubscribe, observers);
130+
return new PublishSubject<T>(onSubscribe, observers, terminalState);
86131
}
87132

88133
private final ConcurrentHashMap<Subscription, Observer<T>> observers;
134+
private final AtomicReference<Notification<T>> terminalState;
89135

90-
protected PublishSubject(Func1<Observer<T>, Subscription> onSubscribe, ConcurrentHashMap<Subscription, Observer<T>> observers) {
136+
protected PublishSubject(Func1<Observer<T>, Subscription> onSubscribe, ConcurrentHashMap<Subscription, Observer<T>> observers, AtomicReference<Notification<T>> terminalState) {
91137
super(onSubscribe);
92138
this.observers = observers;
139+
this.terminalState = terminalState;
93140
}
94141

95142
@Override
96143
public void onCompleted() {
97-
for (Observer<T> observer : observers.values()) {
144+
terminalState.set(new Notification<T>());
145+
for (Observer<T> observer : snapshotOfValues()) {
98146
observer.onCompleted();
99147
}
148+
observers.clear();
100149
}
101150

102151
@Override
103152
public void onError(Exception e) {
104-
for (Observer<T> observer : observers.values()) {
153+
terminalState.set(new Notification<T>(e));
154+
for (Observer<T> observer : snapshotOfValues()) {
105155
observer.onError(e);
106156
}
157+
observers.clear();
107158
}
108159

109160
@Override
110161
public void onNext(T args) {
111-
for (Observer<T> observer : observers.values()) {
162+
for (Observer<T> observer : snapshotOfValues()) {
112163
observer.onNext(args);
113164
}
114165
}
115166

167+
/**
168+
* Current snapshot of 'values()' so that concurrent modifications aren't included.
169+
*
170+
* This makes it behave deterministically in a single-threaded execution when nesting subscribes.
171+
*
172+
* In multi-threaded execution it will cause new subscriptions to wait until the following onNext instead
173+
* of possibly being included in the current onNext iteration.
174+
*
175+
* @return List<Observer<T>>
176+
*/
177+
private Collection<Observer<T>> snapshotOfValues() {
178+
return new ArrayList<Observer<T>>(observers.values());
179+
}
180+
116181
public static class UnitTest {
117182
@Test
118183
public void test() {
@@ -307,6 +372,75 @@ private void assertObservedUntilTwo(Observer<String> aObserver)
307372
verify(aObserver, Mockito.never()).onCompleted();
308373
}
309374

375+
/**
376+
* Test that subscribing after onError/onCompleted immediately terminates instead of causing it to hang.
377+
*
378+
* Nothing is mentioned in Rx Guidelines for what to do in this case so I'm doing what seems to make sense
379+
* which is:
380+
*
381+
* - cache terminal state (onError/onCompleted)
382+
* - any subsequent subscriptions will immediately receive the terminal state rather than start a new subscription
383+
*
384+
*/
385+
@Test
386+
public void testUnsubscribeAfterOnCompleted() {
387+
PublishSubject<Object> subject = PublishSubject.create();
388+
389+
@SuppressWarnings("unchecked")
390+
Observer<String> anObserver = mock(Observer.class);
391+
subject.subscribe(anObserver);
392+
393+
subject.onNext("one");
394+
subject.onNext("two");
395+
subject.onCompleted();
396+
397+
InOrder inOrder = inOrder(anObserver);
398+
inOrder.verify(anObserver, times(1)).onNext("one");
399+
inOrder.verify(anObserver, times(1)).onNext("two");
400+
inOrder.verify(anObserver, times(1)).onCompleted();
401+
inOrder.verify(anObserver, Mockito.never()).onError(any(Exception.class));
402+
403+
@SuppressWarnings("unchecked")
404+
Observer<String> anotherObserver = mock(Observer.class);
405+
subject.subscribe(anotherObserver);
406+
407+
inOrder = inOrder(anotherObserver);
408+
inOrder.verify(anotherObserver, Mockito.never()).onNext("one");
409+
inOrder.verify(anotherObserver, Mockito.never()).onNext("two");
410+
inOrder.verify(anotherObserver, times(1)).onCompleted();
411+
inOrder.verify(anotherObserver, Mockito.never()).onError(any(Exception.class));
412+
}
413+
414+
@Test
415+
public void testUnsubscribeAfterOnError() {
416+
PublishSubject<Object> subject = PublishSubject.create();
417+
RuntimeException exception = new RuntimeException("failure");
418+
419+
@SuppressWarnings("unchecked")
420+
Observer<String> anObserver = mock(Observer.class);
421+
subject.subscribe(anObserver);
422+
423+
subject.onNext("one");
424+
subject.onNext("two");
425+
subject.onError(exception);
426+
427+
InOrder inOrder = inOrder(anObserver);
428+
inOrder.verify(anObserver, times(1)).onNext("one");
429+
inOrder.verify(anObserver, times(1)).onNext("two");
430+
inOrder.verify(anObserver, times(1)).onError(exception);
431+
inOrder.verify(anObserver, Mockito.never()).onCompleted();
432+
433+
@SuppressWarnings("unchecked")
434+
Observer<String> anotherObserver = mock(Observer.class);
435+
subject.subscribe(anotherObserver);
436+
437+
inOrder = inOrder(anotherObserver);
438+
inOrder.verify(anotherObserver, Mockito.never()).onNext("one");
439+
inOrder.verify(anotherObserver, Mockito.never()).onNext("two");
440+
inOrder.verify(anotherObserver, times(1)).onError(exception);
441+
inOrder.verify(anotherObserver, Mockito.never()).onCompleted();
442+
}
443+
310444
@Test
311445
public void testUnsubscribe()
312446
{
@@ -340,5 +474,58 @@ public void call(PublishSubject<Object> DefaultSubject)
340474
}
341475
});
342476
}
477+
478+
@Test
479+
public void testNestedSubscribe() {
480+
final PublishSubject<Integer> s = PublishSubject.create();
481+
482+
final AtomicInteger countParent = new AtomicInteger();
483+
final AtomicInteger countChildren = new AtomicInteger();
484+
final AtomicInteger countTotal = new AtomicInteger();
485+
486+
final ArrayList<String> list = new ArrayList<String>();
487+
488+
s.mapMany(new Func1<Integer, Observable<String>>() {
489+
490+
@Override
491+
public Observable<String> call(final Integer v) {
492+
countParent.incrementAndGet();
493+
494+
// then subscribe to subject again (it will not receive the previous value)
495+
return s.map(new Func1<Integer, String>() {
496+
497+
@Override
498+
public String call(Integer v2) {
499+
countChildren.incrementAndGet();
500+
return "Parent: " + v + " Child: " + v2;
501+
}
502+
503+
});
504+
}
505+
506+
}).subscribe(new Action1<String>() {
507+
508+
@Override
509+
public void call(String v) {
510+
countTotal.incrementAndGet();
511+
list.add(v);
512+
}
513+
514+
});
515+
516+
517+
for(int i=0; i<10; i++) {
518+
s.onNext(i);
519+
}
520+
s.onCompleted();
521+
522+
// System.out.println("countParent: " + countParent.get());
523+
// System.out.println("countChildren: " + countChildren.get());
524+
// System.out.println("countTotal: " + countTotal.get());
525+
526+
// 9+8+7+6+5+4+3+2+1+0 == 45
527+
assertEquals(45, list.size());
528+
}
529+
343530
}
344531
}

0 commit comments

Comments
 (0)