Skip to content

Commit 381b2eb

Browse files
authored
Add StateFlow<T>.onSubscription (#4380)
Fixes #4275
1 parent e5a7b42 commit 381b2eb

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,7 @@ public final class kotlinx/coroutines/flow/FlowKt {
10981098
public static synthetic fun onErrorReturn$default (Lkotlinx/coroutines/flow/Flow;Ljava/lang/Object;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lkotlinx/coroutines/flow/Flow;
10991099
public static final fun onStart (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
11001100
public static final fun onSubscription (Lkotlinx/coroutines/flow/SharedFlow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/SharedFlow;
1101+
public static final fun onSubscription (Lkotlinx/coroutines/flow/StateFlow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/StateFlow;
11011102
public static final fun produceIn (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/CoroutineScope;)Lkotlinx/coroutines/channels/ReceiveChannel;
11021103
public static final fun publish (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow;
11031104
public static final fun publish (Lkotlinx/coroutines/flow/Flow;I)Lkotlinx/coroutines/flow/Flow;

kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,7 @@ final fun <#A: kotlin/Any?> (kotlinx.coroutines.flow/SharedFlow<#A>).kotlinx.cor
904904
final fun <#A: kotlin/Any?> (kotlinx.coroutines.flow/SharedFlow<#A>).kotlinx.coroutines.flow/onSubscription(kotlin.coroutines/SuspendFunction1<kotlinx.coroutines.flow/FlowCollector<#A>, kotlin/Unit>): kotlinx.coroutines.flow/SharedFlow<#A> // kotlinx.coroutines.flow/onSubscription|[email protected]<0:0>(kotlin.coroutines.SuspendFunction1<kotlinx.coroutines.flow.FlowCollector<0:0>,kotlin.Unit>){0§<kotlin.Any?>}[0]
905905
final fun <#A: kotlin/Any?> (kotlinx.coroutines.flow/StateFlow<#A>).kotlinx.coroutines.flow/conflate(): kotlinx.coroutines.flow/Flow<#A> // kotlinx.coroutines.flow/conflate|[email protected]<0:0>(){0§<kotlin.Any?>}[0]
906906
final fun <#A: kotlin/Any?> (kotlinx.coroutines.flow/StateFlow<#A>).kotlinx.coroutines.flow/distinctUntilChanged(): kotlinx.coroutines.flow/Flow<#A> // kotlinx.coroutines.flow/distinctUntilChanged|[email protected]<0:0>(){0§<kotlin.Any?>}[0]
907+
final fun <#A: kotlin/Any?> (kotlinx.coroutines.flow/StateFlow<#A>).kotlinx.coroutines.flow/onSubscription(kotlin.coroutines/SuspendFunction1<kotlinx.coroutines.flow/FlowCollector<#A>, kotlin/Unit>): kotlinx.coroutines.flow/StateFlow<#A> // kotlinx.coroutines.flow/onSubscription|[email protected]<0:0>(kotlin.coroutines.SuspendFunction1<kotlinx.coroutines.flow.FlowCollector<0:0>,kotlin.Unit>){0§<kotlin.Any?>}[0]
907908
final fun <#A: kotlin/Any?> (kotlinx.coroutines.selects/SelectBuilder<#A>).kotlinx.coroutines.selects/onTimeout(kotlin.time/Duration, kotlin.coroutines/SuspendFunction0<#A>) // kotlinx.coroutines.selects/onTimeout|[email protected]<0:0>(kotlin.time.Duration;kotlin.coroutines.SuspendFunction0<0:0>){0§<kotlin.Any?>}[0]
908909
final fun <#A: kotlin/Any?> (kotlinx.coroutines.selects/SelectBuilder<#A>).kotlinx.coroutines.selects/onTimeout(kotlin/Long, kotlin.coroutines/SuspendFunction0<#A>) // kotlinx.coroutines.selects/onTimeout|[email protected]<0:0>(kotlin.Long;kotlin.coroutines.SuspendFunction0<0:0>){0§<kotlin.Any?>}[0]
909910
final fun <#A: kotlin/Any?> (kotlinx.coroutines/CompletableDeferred<#A>).kotlinx.coroutines/completeWith(kotlin/Result<#A>): kotlin/Boolean // kotlinx.coroutines/completeWith|[email protected]<0:0>(kotlin.Result<0:0>){0§<kotlin.Any?>}[0]

kotlinx-coroutines-core/common/src/flow/operators/Share.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,29 @@ private class SubscribedSharedFlow<T>(
412412
sharedFlow.collect(SubscribedFlowCollector(collector, action))
413413
}
414414

415+
/**
416+
* Returns a flow that invokes the given [action] **after** this state flow starts to be collected
417+
* (after the subscription is registered).
418+
*
419+
* The [action] is called before any value is emitted from the upstream
420+
* flow to this subscription but after the subscription is established. It is guaranteed that all emissions to
421+
* the upstream flow that happen inside or immediately after this `onSubscription` action will be
422+
* collected by this subscription.
423+
*
424+
* The receiver of the [action] is [FlowCollector], so `onSubscription` can emit additional elements.
425+
*/
426+
public fun <T> StateFlow<T>.onSubscription(action: suspend FlowCollector<T>.() -> Unit): StateFlow<T> =
427+
SubscribedStateFlow(this, action)
428+
429+
@OptIn(ExperimentalForInheritanceCoroutinesApi::class)
430+
private class SubscribedStateFlow<T>(
431+
private val stateFlow: StateFlow<T>,
432+
private val action: suspend FlowCollector<T>.() -> Unit
433+
) : StateFlow<T> by stateFlow {
434+
override suspend fun collect(collector: FlowCollector<T>) =
435+
stateFlow.collect(SubscribedFlowCollector(collector, action))
436+
}
437+
415438
internal class SubscribedFlowCollector<T>(
416439
private val collector: FlowCollector<T>,
417440
private val action: suspend FlowCollector<T>.() -> Unit

kotlinx-coroutines-core/common/test/flow/sharing/StateFlowTest.kt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,67 @@ class StateFlowTest : TestBase() {
110110
}
111111
}
112112

113+
@Test
114+
fun testOnSubscription() = runTest {
115+
expect(1)
116+
val state = MutableStateFlow("initial") // "initial" gets lost, replaced by "D"
117+
state
118+
.onSubscription {
119+
emit("collector->A")
120+
state.value = "A" // gets lost, replaced by "B"
121+
}
122+
.onSubscription {
123+
emit("collector->B")
124+
state.value = "B"
125+
}
126+
.onStart {
127+
emit("collector->C")
128+
state.value = "C" // gets lost, replaced by "A"
129+
}
130+
.onStart {
131+
emit("collector->D")
132+
state.value = "D" // gets lost, replaced by "C"
133+
}
134+
.onEach {
135+
when (it) {
136+
"collector->D" -> expect(2)
137+
"collector->C" -> expect(3)
138+
"collector->A" -> expect(4)
139+
"collector->B" -> expect(5)
140+
"B" -> {
141+
expect(6)
142+
currentCoroutineContext().cancel()
143+
}
144+
else -> expectUnreached()
145+
}
146+
}
147+
.launchIn(this)
148+
.join()
149+
finish(7)
150+
}
151+
152+
@Test
153+
@Suppress("DEPRECATION") // 'catch'
154+
fun testOnSubscriptionThrows() = runTest {
155+
expect(1)
156+
val state = MutableStateFlow("initial")
157+
state
158+
.onSubscription {
159+
expect(2)
160+
throw TestException()
161+
}
162+
.catch { e ->
163+
assertIs<TestException>(e)
164+
expect(3)
165+
}
166+
.collect {
167+
// onSubscription throws before "initial" is emitted, so no value is collected
168+
expectUnreached()
169+
}
170+
assertEquals(0, state.subscriptionCount.value)
171+
finish(4)
172+
}
173+
113174
@Test
114175
public fun testOnSubscriptionWithException() = runTest {
115176
expect(1)

0 commit comments

Comments
 (0)