Skip to content

Commit 5059158

Browse files
authored
Fix exception propagation in Async API methods (#1479)
- Resolve an issue where exceptions thrown during thenRun, thenSupply, and related operations in the asynchronous API were not properly propagated to the completion callback. This issue was addressed by replacing `unsafeFinish` with `finish`, ensuring that exceptions are caught and correctly passed to the completion callback when executed on different threads. - Update existing Async API tests to ensure they simulate separate async thread execution. - Modify the async callback to catch and handle exceptions locally. Exceptions are now directly processed and passed as an error argument to the callback function, avoiding propagation to the parent callback. - Move `callback.onResult` outside the catch block to ensure it's not invoked twice when an exception occurs. JAVA-5562
1 parent f2cfac7 commit 5059158

9 files changed

+343
-57
lines changed

driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import com.mongodb.lang.Nullable;
2020

21+
import java.util.concurrent.atomic.AtomicBoolean;
22+
2123
/**
2224
* See {@link AsyncRunnable}
2325
* <p>
@@ -33,4 +35,28 @@ public interface AsyncFunction<T, R> {
3335
* @param callback the callback
3436
*/
3537
void unsafeFinish(T value, SingleResultCallback<R> callback);
38+
39+
/**
40+
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
41+
*
42+
* @param callback the callback provided by the method the chain is used in.
43+
*/
44+
default void finish(final T value, final SingleResultCallback<R> callback) {
45+
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
46+
try {
47+
this.unsafeFinish(value, (v, e) -> {
48+
if (!callbackInvoked.compareAndSet(false, true)) {
49+
throw new AssertionError(String.format("Callback has been already completed. It could happen "
50+
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
51+
}
52+
callback.onResult(v, e);
53+
});
54+
} catch (Throwable t) {
55+
if (!callbackInvoked.compareAndSet(false, true)) {
56+
throw t;
57+
} else {
58+
callback.completeExceptionally(t);
59+
}
60+
}
61+
}
3662
}

driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) {
171171
return (c) -> {
172172
this.unsafeFinish((r, e) -> {
173173
if (e == null) {
174-
runnable.unsafeFinish(c);
174+
/* If 'runnable' is executed on a different thread from the one that executed the initial 'finish()',
175+
then invoking 'finish()' within 'runnable' will catch and propagate any exceptions to 'c' (the callback). */
176+
runnable.finish(c);
175177
} else {
176178
c.completeExceptionally(e);
177179
}
@@ -236,7 +238,7 @@ default AsyncRunnable thenRunIf(final Supplier<Boolean> condition, final AsyncRu
236238
return;
237239
}
238240
if (matched) {
239-
runnable.unsafeFinish(callback);
241+
runnable.finish(callback);
240242
} else {
241243
callback.complete(callback);
242244
}
@@ -253,7 +255,7 @@ default <R> AsyncSupplier<R> thenSupply(final AsyncSupplier<R> supplier) {
253255
return (c) -> {
254256
this.unsafeFinish((r, e) -> {
255257
if (e == null) {
256-
supplier.unsafeFinish(c);
258+
supplier.finish(c);
257259
} else {
258260
c.completeExceptionally(e);
259261
}

driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import com.mongodb.lang.Nullable;
2020

21+
import java.util.concurrent.atomic.AtomicBoolean;
2122
import java.util.function.Predicate;
2223

2324

@@ -54,18 +55,25 @@ default void unsafeFinish(@Nullable final Void value, final SingleResultCallback
5455
}
5556

5657
/**
57-
* Must be invoked at end of async chain.
58+
* Must be invoked at end of async chain or when executing a callback handler supplied by the caller.
59+
*
60+
* @see #thenApply(AsyncFunction)
61+
* @see #thenConsume(AsyncConsumer)
62+
* @see #onErrorIf(Predicate, AsyncFunction)
5863
* @param callback the callback provided by the method the chain is used in
5964
*/
6065
default void finish(final SingleResultCallback<T> callback) {
61-
final boolean[] callbackInvoked = {false};
66+
final AtomicBoolean callbackInvoked = new AtomicBoolean(false);
6267
try {
6368
this.unsafeFinish((v, e) -> {
64-
callbackInvoked[0] = true;
69+
if (!callbackInvoked.compareAndSet(false, true)) {
70+
throw new AssertionError(String.format("Callback has been already completed. It could happen "
71+
+ "if code throws an exception after invoking an async method. Value: %s", v), e);
72+
}
6573
callback.onResult(v, e);
6674
});
6775
} catch (Throwable t) {
68-
if (callbackInvoked[0]) {
76+
if (!callbackInvoked.compareAndSet(false, true)) {
6977
throw t;
7078
} else {
7179
callback.completeExceptionally(t);
@@ -80,9 +88,9 @@ default void finish(final SingleResultCallback<T> callback) {
8088
*/
8189
default <R> AsyncSupplier<R> thenApply(final AsyncFunction<T, R> function) {
8290
return (c) -> {
83-
this.unsafeFinish((v, e) -> {
91+
this.finish((v, e) -> {
8492
if (e == null) {
85-
function.unsafeFinish(v, c);
93+
function.finish(v, c);
8694
} else {
8795
c.completeExceptionally(e);
8896
}
@@ -99,7 +107,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer<T> consumer) {
99107
return (c) -> {
100108
this.unsafeFinish((v, e) -> {
101109
if (e == null) {
102-
consumer.unsafeFinish(v, c);
110+
consumer.finish(v, c);
103111
} else {
104112
c.completeExceptionally(e);
105113
}
@@ -131,7 +139,7 @@ default AsyncSupplier<T> onErrorIf(
131139
return;
132140
}
133141
if (errorMatched) {
134-
errorFunction.unsafeFinish(e, callback);
142+
errorFunction.finish(e, callback);
135143
} else {
136144
callback.completeExceptionally(e);
137145
}

driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,7 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
610610
return;
611611
}
612612
assertNotNull(responseBuffers);
613+
T commandResult;
613614
try {
614615
updateSessionContext(operationContext.getSessionContext(), responseBuffers);
615616
boolean commandOk =
@@ -624,13 +625,14 @@ private <T> void sendCommandMessageAsync(final int messageId, final Decoder<T> d
624625
}
625626
commandEventSender.sendSucceededEvent(responseBuffers);
626627

627-
T result1 = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext());
628-
callback.onResult(result1, null);
628+
commandResult = getCommandResult(decoder, responseBuffers, messageId, operationContext.getTimeoutContext());
629629
} catch (Throwable localThrowable) {
630630
callback.onResult(null, localThrowable);
631+
return;
631632
} finally {
632633
responseBuffers.close();
633634
}
635+
callback.onResult(commandResult, null);
634636
}));
635637
}
636638
});

driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,14 @@ public void startHandshakeAsync(final InternalConnection internalConnection, fin
101101
callback.onResult(null, t instanceof MongoException ? mapHelloException((MongoException) t) : t);
102102
} else {
103103
setSpeculativeAuthenticateResponse(helloResult);
104-
callback.onResult(createInitializationDescription(helloResult, internalConnection, startTime), null);
104+
InternalConnectionInitializationDescription initializationDescription;
105+
try {
106+
initializationDescription = createInitializationDescription(helloResult, internalConnection, startTime);
107+
} catch (Throwable localThrowable) {
108+
callback.onResult(null, localThrowable);
109+
return;
110+
}
111+
callback.onResult(initializationDescription, null);
105112
}
106113
});
107114
}

driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java renamed to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsAbstractTest.java

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525

2626
import static com.mongodb.assertions.Assertions.assertNotNull;
2727
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
28-
import static org.junit.jupiter.api.Assertions.assertThrows;
2928

30-
final class AsyncFunctionsTest extends AsyncFunctionsTestAbstract {
29+
abstract class AsyncFunctionsAbstractTest extends AsyncFunctionsTestBase {
3130
private static final TimeoutContext TIMEOUT_CONTEXT = new TimeoutContext(new TimeoutSettings(0, 0, 0, 0L, 0));
31+
3232
@Test
3333
void test1Method() {
3434
// the number of expected variations is often: 1 + N methods invoked
@@ -760,25 +760,6 @@ void testVariables() {
760760
});
761761
}
762762

763-
@Test
764-
void testInvalid() {
765-
setIsTestingAbruptCompletion(false);
766-
setAsyncStep(true);
767-
assertThrows(IllegalStateException.class, () -> {
768-
beginAsync().thenRun(c -> {
769-
async(3, c);
770-
throw new IllegalStateException("must not cause second callback invocation");
771-
}).finish((v, e) -> {});
772-
});
773-
assertThrows(IllegalStateException.class, () -> {
774-
beginAsync().thenRun(c -> {
775-
async(3, c);
776-
}).finish((v, e) -> {
777-
throw new IllegalStateException("must not cause second callback invocation");
778-
});
779-
});
780-
}
781-
782763
@Test
783764
void testDerivation() {
784765
// Demonstrates the progression from nested async to the API.
@@ -866,5 +847,4 @@ void testDerivation() {
866847
}).finish(callback);
867848
});
868849
}
869-
870850
}

driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestAbstract.java renamed to driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTestBase.java

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@
1717
package com.mongodb.internal.async;
1818

1919
import com.mongodb.client.TestListener;
20+
import org.junit.jupiter.api.AfterEach;
21+
import org.junit.jupiter.api.BeforeEach;
2022
import org.opentest4j.AssertionFailedError;
2123

2224
import java.util.ArrayList;
2325
import java.util.List;
24-
import java.util.concurrent.atomic.AtomicBoolean;
26+
import java.util.concurrent.CompletableFuture;
27+
import java.util.concurrent.ExecutionException;
28+
import java.util.concurrent.ExecutorService;
29+
import java.util.concurrent.TimeUnit;
30+
import java.util.concurrent.TimeoutException;
2531
import java.util.concurrent.atomic.AtomicReference;
2632
import java.util.function.Consumer;
2733
import java.util.function.Supplier;
@@ -31,11 +37,12 @@
3137
import static org.junit.jupiter.api.Assertions.assertTrue;
3238
import static org.junit.jupiter.api.Assertions.fail;
3339

34-
public class AsyncFunctionsTestAbstract {
40+
public abstract class AsyncFunctionsTestBase {
3541

3642
private final TestListener listener = new TestListener();
3743
private final InvocationTracker invocationTracker = new InvocationTracker();
3844
private boolean isTestingAbruptCompletion = false;
45+
private ExecutorService asyncExecutor;
3946

4047
void setIsTestingAbruptCompletion(final boolean b) {
4148
isTestingAbruptCompletion = b;
@@ -53,6 +60,23 @@ public void listenerAdd(final String s) {
5360
listener.add(s);
5461
}
5562

63+
/**
64+
* Create an executor service for async operations before each test.
65+
*
66+
* @return the executor service.
67+
*/
68+
public abstract ExecutorService createAsyncExecutor();
69+
70+
@BeforeEach
71+
public void setUp() {
72+
asyncExecutor = createAsyncExecutor();
73+
}
74+
75+
@AfterEach
76+
public void shutDown() {
77+
asyncExecutor.shutdownNow();
78+
}
79+
5680
void plain(final int i) {
5781
int cur = invocationTracker.getNextOption(2);
5882
if (cur == 0) {
@@ -98,32 +122,47 @@ Integer syncReturns(final int i) {
98122
return affectedReturns(i);
99123
}
100124

125+
126+
public void submit(final Runnable task) {
127+
asyncExecutor.execute(task);
128+
}
101129
void async(final int i, final SingleResultCallback<Void> callback) {
102130
assertTrue(invocationTracker.isAsyncStep);
103131
if (isTestingAbruptCompletion) {
132+
/* We should not test for abrupt completion in a separate thread. Once a callback is registered for an async operation,
133+
the Async Framework does not handle exceptions thrown outside of callbacks by the executing thread. Such exception management
134+
should be the responsibility of the thread conducting the asynchronous operations. */
104135
affected(i);
105-
callback.complete(callback);
106-
107-
} else {
108-
try {
109-
affected(i);
136+
submit(() -> {
110137
callback.complete(callback);
111-
} catch (Throwable t) {
112-
callback.onResult(null, t);
113-
}
138+
});
139+
} else {
140+
submit(() -> {
141+
try {
142+
affected(i);
143+
callback.complete(callback);
144+
} catch (Throwable t) {
145+
callback.onResult(null, t);
146+
}
147+
});
114148
}
115149
}
116150

117151
void asyncReturns(final int i, final SingleResultCallback<Integer> callback) {
118152
assertTrue(invocationTracker.isAsyncStep);
119153
if (isTestingAbruptCompletion) {
120-
callback.complete(affectedReturns(i));
154+
int result = affectedReturns(i);
155+
submit(() -> {
156+
callback.complete(result);
157+
});
121158
} else {
122-
try {
123-
callback.complete(affectedReturns(i));
124-
} catch (Throwable t) {
125-
callback.onResult(null, t);
126-
}
159+
submit(() -> {
160+
try {
161+
callback.complete(affectedReturns(i));
162+
} catch (Throwable t) {
163+
callback.onResult(null, t);
164+
}
165+
});
127166
}
128167
}
129168

@@ -200,24 +239,26 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
200239

201240
AtomicReference<T> actualValue = new AtomicReference<>();
202241
AtomicReference<Throwable> actualException = new AtomicReference<>();
203-
AtomicBoolean wasCalled = new AtomicBoolean(false);
242+
CompletableFuture<Void> wasCalledFuture = new CompletableFuture<>();
204243
try {
205244
async.accept((v, e) -> {
206245
actualValue.set(v);
207246
actualException.set(e);
208-
if (wasCalled.get()) {
247+
if (wasCalledFuture.isDone()) {
209248
fail();
210249
}
211-
wasCalled.set(true);
250+
wasCalledFuture.complete(null);
212251
});
213252
} catch (Throwable e) {
214253
fail("async threw instead of using callback");
215254
}
216255

256+
await(wasCalledFuture, "Callback should have been called");
257+
217258
// The following code can be used to debug variations:
218259
// System.out.println("===VARIATION START");
219260
// System.out.println("sync: " + expectedEvents);
220-
// System.out.println("callback called?: " + wasCalled.get());
261+
// System.out.println("callback called?: " + wasCalledFuture.isDone());
221262
// System.out.println("value -- sync: " + expectedValue + " -- async: " + actualValue.get());
222263
// System.out.println("excep -- sync: " + expectedException + " -- async: " + actualException.get());
223264
// System.out.println("exception mode: " + (isTestingAbruptCompletion
@@ -229,7 +270,7 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
229270
throw (AssertionFailedError) actualException.get();
230271
}
231272

232-
assertTrue(wasCalled.get(), "callback should have been called");
273+
assertTrue(wasCalledFuture.isDone(), "callback should have been called");
233274
assertEquals(expectedEvents, listener.getEventStrings(), "steps should have matched");
234275
assertEquals(expectedValue, actualValue.get());
235276
assertEquals(expectedException == null, actualException.get() == null,
@@ -242,6 +283,14 @@ private <T> void assertBehavesSame(final Supplier<T> sync, final Runnable betwee
242283
listener.clear();
243284
}
244285

286+
protected <T> T await(final CompletableFuture<T> voidCompletableFuture, final String errorMessage) {
287+
try {
288+
return voidCompletableFuture.get(1, TimeUnit.MINUTES);
289+
} catch (InterruptedException | ExecutionException | TimeoutException e) {
290+
throw new AssertionError(errorMessage);
291+
}
292+
}
293+
245294
/**
246295
* Tracks invocations: allows testing of all variations of a method calls
247296
*/

0 commit comments

Comments
 (0)