15
15
*/
16
16
package rx .subjects ;
17
17
18
+ import static org .junit .Assert .*;
18
19
import static org .mockito .Matchers .*;
19
20
import static org .mockito .Mockito .*;
20
21
21
22
import java .util .ArrayList ;
23
+ import java .util .Collection ;
22
24
import java .util .List ;
23
25
import java .util .concurrent .ConcurrentHashMap ;
24
26
import java .util .concurrent .atomic .AtomicBoolean ;
27
+ import java .util .concurrent .atomic .AtomicInteger ;
25
28
import java .util .concurrent .atomic .AtomicReference ;
26
29
27
30
import junit .framework .Assert ;
28
31
29
32
import org .junit .Test ;
33
+ import org .mockito .InOrder ;
30
34
import org .mockito .Mockito ;
31
35
32
36
import rx .Notification ;
33
37
import rx .Observable ;
34
38
import rx .Observer ;
35
39
import rx .Subscription ;
36
40
import rx .operators .AtomicObservableSubscription ;
41
+ import rx .subscriptions .Subscriptions ;
37
42
import rx .util .functions .Action1 ;
38
43
import rx .util .functions .Func0 ;
39
44
import rx .util .functions .Func1 ;
62
67
public class PublishSubject <T > extends Subject <T , T > {
63
68
public static <T > PublishSubject <T > create () {
64
69
final ConcurrentHashMap <Subscription , Observer <T >> observers = new ConcurrentHashMap <Subscription , Observer <T >>();
65
-
70
+ final AtomicReference <Notification <T >> terminalState = new AtomicReference <Notification <T >>();
71
+
66
72
Func1 <Observer <T >, Subscription > onSubscribe = new Func1 <Observer <T >, Subscription >() {
67
73
@ Override
68
74
public Subscription call (Observer <T > observer ) {
75
+ // first check if terminal state exist
76
+ Subscription s = checkTerminalState (observer );
77
+ if (s != null ) return s ;
78
+
69
79
final AtomicObservableSubscription subscription = new AtomicObservableSubscription ();
70
80
71
81
subscription .wrap (new Subscription () {
@@ -78,41 +88,96 @@ public void unsubscribe() {
78
88
79
89
// on subscribe add it to the map of outbound observers to notify
80
90
observers .put (subscription , observer );
91
+
92
+ // check terminal state again
93
+ s = checkTerminalState (observer );
94
+ if (s != null ) return s ;
95
+
96
+ /**
97
+ * NOTE: There is a race condition here.
98
+ *
99
+ * 1) terminal state gets set in onError or onCompleted
100
+ * 2) observers.put adds a new observer
101
+ * 3) checkTerminalState emits onError/onCompleted
102
+ * 4) onError or onCompleted also emits onError/onCompleted since it was adds to observers
103
+ *
104
+ * Thus the terminal state could end up being sent twice.
105
+ *
106
+ * I'm going to leave this for now as AtomicObserver will protect against this
107
+ * and I'd rather not add blocking synchronization in here unless the above race condition
108
+ * truly is an issue.
109
+ */
110
+
81
111
return subscription ;
82
112
}
113
+
114
+ private Subscription checkTerminalState (Observer <T > observer ) {
115
+ Notification <T > n = terminalState .get ();
116
+ if (n != null ) {
117
+ // we are terminated to immediately emit and don't continue with subscription
118
+ if (n .isOnCompleted ()) {
119
+ observer .onCompleted ();
120
+ } else {
121
+ observer .onError (n .getException ());
122
+ }
123
+ return Subscriptions .empty ();
124
+ } else {
125
+ return null ;
126
+ }
127
+ }
83
128
};
84
129
85
- return new PublishSubject <T >(onSubscribe , observers );
130
+ return new PublishSubject <T >(onSubscribe , observers , terminalState );
86
131
}
87
132
88
133
private final ConcurrentHashMap <Subscription , Observer <T >> observers ;
134
+ private final AtomicReference <Notification <T >> terminalState ;
89
135
90
- protected PublishSubject (Func1 <Observer <T >, Subscription > onSubscribe , ConcurrentHashMap <Subscription , Observer <T >> observers ) {
136
+ protected PublishSubject (Func1 <Observer <T >, Subscription > onSubscribe , ConcurrentHashMap <Subscription , Observer <T >> observers , AtomicReference < Notification < T >> terminalState ) {
91
137
super (onSubscribe );
92
138
this .observers = observers ;
139
+ this .terminalState = terminalState ;
93
140
}
94
141
95
142
@ Override
96
143
public void onCompleted () {
97
- for (Observer <T > observer : observers .values ()) {
144
+ terminalState .set (new Notification <T >());
145
+ for (Observer <T > observer : snapshotOfValues ()) {
98
146
observer .onCompleted ();
99
147
}
148
+ observers .clear ();
100
149
}
101
150
102
151
@ Override
103
152
public void onError (Exception e ) {
104
- for (Observer <T > observer : observers .values ()) {
153
+ terminalState .set (new Notification <T >(e ));
154
+ for (Observer <T > observer : snapshotOfValues ()) {
105
155
observer .onError (e );
106
156
}
157
+ observers .clear ();
107
158
}
108
159
109
160
@ Override
110
161
public void onNext (T args ) {
111
- for (Observer <T > observer : observers . values ()) {
162
+ for (Observer <T > observer : snapshotOfValues ()) {
112
163
observer .onNext (args );
113
164
}
114
165
}
115
166
167
+ /**
168
+ * Current snapshot of 'values()' so that concurrent modifications aren't included.
169
+ *
170
+ * This makes it behave deterministically in a single-threaded execution when nesting subscribes.
171
+ *
172
+ * In multi-threaded execution it will cause new subscriptions to wait until the following onNext instead
173
+ * of possibly being included in the current onNext iteration.
174
+ *
175
+ * @return List<Observer<T>>
176
+ */
177
+ private Collection <Observer <T >> snapshotOfValues () {
178
+ return new ArrayList <Observer <T >>(observers .values ());
179
+ }
180
+
116
181
public static class UnitTest {
117
182
@ Test
118
183
public void test () {
@@ -307,6 +372,75 @@ private void assertObservedUntilTwo(Observer<String> aObserver)
307
372
verify (aObserver , Mockito .never ()).onCompleted ();
308
373
}
309
374
375
+ /**
376
+ * Test that subscribing after onError/onCompleted immediately terminates instead of causing it to hang.
377
+ *
378
+ * Nothing is mentioned in Rx Guidelines for what to do in this case so I'm doing what seems to make sense
379
+ * which is:
380
+ *
381
+ * - cache terminal state (onError/onCompleted)
382
+ * - any subsequent subscriptions will immediately receive the terminal state rather than start a new subscription
383
+ *
384
+ */
385
+ @ Test
386
+ public void testUnsubscribeAfterOnCompleted () {
387
+ PublishSubject <Object > subject = PublishSubject .create ();
388
+
389
+ @ SuppressWarnings ("unchecked" )
390
+ Observer <String > anObserver = mock (Observer .class );
391
+ subject .subscribe (anObserver );
392
+
393
+ subject .onNext ("one" );
394
+ subject .onNext ("two" );
395
+ subject .onCompleted ();
396
+
397
+ InOrder inOrder = inOrder (anObserver );
398
+ inOrder .verify (anObserver , times (1 )).onNext ("one" );
399
+ inOrder .verify (anObserver , times (1 )).onNext ("two" );
400
+ inOrder .verify (anObserver , times (1 )).onCompleted ();
401
+ inOrder .verify (anObserver , Mockito .never ()).onError (any (Exception .class ));
402
+
403
+ @ SuppressWarnings ("unchecked" )
404
+ Observer <String > anotherObserver = mock (Observer .class );
405
+ subject .subscribe (anotherObserver );
406
+
407
+ inOrder = inOrder (anotherObserver );
408
+ inOrder .verify (anotherObserver , Mockito .never ()).onNext ("one" );
409
+ inOrder .verify (anotherObserver , Mockito .never ()).onNext ("two" );
410
+ inOrder .verify (anotherObserver , times (1 )).onCompleted ();
411
+ inOrder .verify (anotherObserver , Mockito .never ()).onError (any (Exception .class ));
412
+ }
413
+
414
+ @ Test
415
+ public void testUnsubscribeAfterOnError () {
416
+ PublishSubject <Object > subject = PublishSubject .create ();
417
+ RuntimeException exception = new RuntimeException ("failure" );
418
+
419
+ @ SuppressWarnings ("unchecked" )
420
+ Observer <String > anObserver = mock (Observer .class );
421
+ subject .subscribe (anObserver );
422
+
423
+ subject .onNext ("one" );
424
+ subject .onNext ("two" );
425
+ subject .onError (exception );
426
+
427
+ InOrder inOrder = inOrder (anObserver );
428
+ inOrder .verify (anObserver , times (1 )).onNext ("one" );
429
+ inOrder .verify (anObserver , times (1 )).onNext ("two" );
430
+ inOrder .verify (anObserver , times (1 )).onError (exception );
431
+ inOrder .verify (anObserver , Mockito .never ()).onCompleted ();
432
+
433
+ @ SuppressWarnings ("unchecked" )
434
+ Observer <String > anotherObserver = mock (Observer .class );
435
+ subject .subscribe (anotherObserver );
436
+
437
+ inOrder = inOrder (anotherObserver );
438
+ inOrder .verify (anotherObserver , Mockito .never ()).onNext ("one" );
439
+ inOrder .verify (anotherObserver , Mockito .never ()).onNext ("two" );
440
+ inOrder .verify (anotherObserver , times (1 )).onError (exception );
441
+ inOrder .verify (anotherObserver , Mockito .never ()).onCompleted ();
442
+ }
443
+
310
444
@ Test
311
445
public void testUnsubscribe ()
312
446
{
@@ -340,5 +474,58 @@ public void call(PublishSubject<Object> DefaultSubject)
340
474
}
341
475
});
342
476
}
477
+
478
+ @ Test
479
+ public void testNestedSubscribe () {
480
+ final PublishSubject <Integer > s = PublishSubject .create ();
481
+
482
+ final AtomicInteger countParent = new AtomicInteger ();
483
+ final AtomicInteger countChildren = new AtomicInteger ();
484
+ final AtomicInteger countTotal = new AtomicInteger ();
485
+
486
+ final ArrayList <String > list = new ArrayList <String >();
487
+
488
+ s .mapMany (new Func1 <Integer , Observable <String >>() {
489
+
490
+ @ Override
491
+ public Observable <String > call (final Integer v ) {
492
+ countParent .incrementAndGet ();
493
+
494
+ // then subscribe to subject again (it will not receive the previous value)
495
+ return s .map (new Func1 <Integer , String >() {
496
+
497
+ @ Override
498
+ public String call (Integer v2 ) {
499
+ countChildren .incrementAndGet ();
500
+ return "Parent: " + v + " Child: " + v2 ;
501
+ }
502
+
503
+ });
504
+ }
505
+
506
+ }).subscribe (new Action1 <String >() {
507
+
508
+ @ Override
509
+ public void call (String v ) {
510
+ countTotal .incrementAndGet ();
511
+ list .add (v );
512
+ }
513
+
514
+ });
515
+
516
+
517
+ for (int i =0 ; i <10 ; i ++) {
518
+ s .onNext (i );
519
+ }
520
+ s .onCompleted ();
521
+
522
+ // System.out.println("countParent: " + countParent.get());
523
+ // System.out.println("countChildren: " + countChildren.get());
524
+ // System.out.println("countTotal: " + countTotal.get());
525
+
526
+ // 9+8+7+6+5+4+3+2+1+0 == 45
527
+ assertEquals (45 , list .size ());
528
+ }
529
+
343
530
}
344
531
}
0 commit comments