Skip to content

Commit 7d6f505

Browse files
committed
Protect against AsyncHandlerExtensions calls crash, close AsyncHttpClient#1384
1 parent 4321163 commit 7d6f505

File tree

5 files changed

+165
-62
lines changed

5 files changed

+165
-62
lines changed

client/src/main/java/org/asynchttpclient/netty/channel/ChannelManager.java

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
*/
1414
package org.asynchttpclient.netty.channel;
1515

16+
import static org.asynchttpclient.handler.AsyncHandlerExtensionsUtils.toAsyncHandlerExtensions;
17+
1618
import io.netty.bootstrap.Bootstrap;
1719
import io.netty.buffer.ByteBufAllocator;
1820
import io.netty.channel.Channel;
@@ -268,8 +270,16 @@ public final void tryToOfferChannelToPool(Channel channel, AsyncHandler<?> async
268270
if (channel.isActive() && keepAlive) {
269271
LOGGER.debug("Adding key: {} for channel {}", partitionKey, channel);
270272
Channels.setDiscard(channel);
271-
if (asyncHandler instanceof AsyncHandlerExtensions)
272-
AsyncHandlerExtensions.class.cast(asyncHandler).onConnectionOffer(channel);
273+
274+
final AsyncHandlerExtensions asyncHandlerExtensions = toAsyncHandlerExtensions(asyncHandler);
275+
if (asyncHandlerExtensions != null) {
276+
try {
277+
asyncHandlerExtensions.onConnectionOffer(channel);
278+
} catch (Exception e) {
279+
LOGGER.error("onConnectionOffer crashed", e);
280+
}
281+
}
282+
273283
if (!channelPool.offer(channel, partitionKey)) {
274284
// rejected by pool
275285
closeChannel(channel);
@@ -416,26 +426,15 @@ public EventLoopGroup getEventLoopGroup() {
416426
}
417427

418428
public ClientStats getClientStats() {
419-
Map<String, Long> totalConnectionsPerHost = openChannels
420-
.stream()
421-
.map(Channel::remoteAddress)
422-
.filter(a -> a.getClass() == InetSocketAddress.class)
423-
.map(a -> (InetSocketAddress) a)
424-
.map(InetSocketAddress::getHostName)
425-
.collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
429+
Map<String, Long> totalConnectionsPerHost = openChannels.stream().map(Channel::remoteAddress).filter(a -> a.getClass() == InetSocketAddress.class)
430+
.map(a -> (InetSocketAddress) a).map(InetSocketAddress::getHostName).collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
426431
Map<String, Long> idleConnectionsPerHost = channelPool.getIdleChannelCountPerHost();
427-
Map<String, HostStats> statsPerHost = totalConnectionsPerHost
428-
.entrySet()
429-
.stream()
430-
.collect(Collectors.toMap(
431-
Entry::getKey,
432-
entry -> {
433-
final long totalConnectionCount = entry.getValue();
434-
final long idleConnectionCount = idleConnectionsPerHost.getOrDefault(entry.getKey(), 0L);
435-
final long activeConnectionCount = totalConnectionCount - idleConnectionCount;
436-
return new HostStats(activeConnectionCount, idleConnectionCount);
437-
}
438-
));
432+
Map<String, HostStats> statsPerHost = totalConnectionsPerHost.entrySet().stream().collect(Collectors.toMap(Entry::getKey, entry -> {
433+
final long totalConnectionCount = entry.getValue();
434+
final long idleConnectionCount = idleConnectionsPerHost.getOrDefault(entry.getKey(), 0L);
435+
final long activeConnectionCount = totalConnectionCount - idleConnectionCount;
436+
return new HostStats(activeConnectionCount, idleConnectionCount);
437+
}));
439438
return new ClientStats(statsPerHost);
440439
}
441440

client/src/main/java/org/asynchttpclient/netty/channel/NettyConnectListener.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,28 +126,48 @@ public void operationComplete(Future<? super Void> future) throws Exception {
126126
try {
127127
sslHandler = channelManager.addSslHandler(channel.pipeline(), uri, request.getVirtualHost());
128128
} catch (Exception sslError) {
129-
NettyConnectListener.this.onFailure(channel, sslError);
129+
onFailure(channel, sslError);
130130
return;
131131
}
132132

133133
final AsyncHandlerExtensions asyncHandlerExtensions = toAsyncHandlerExtensions(future.getAsyncHandler());
134134

135-
if (asyncHandlerExtensions != null)
136-
asyncHandlerExtensions.onTlsHandshakeAttempt();
135+
if (asyncHandlerExtensions != null) {
136+
try {
137+
asyncHandlerExtensions.onTlsHandshakeAttempt();
138+
} catch (Exception e) {
139+
LOGGER.error("onTlsHandshakeAttempt crashed", e);
140+
onFailure(channel, e);
141+
return;
142+
}
143+
}
137144

138145
sslHandler.handshakeFuture().addListener(new SimpleFutureListener<Channel>() {
139-
140146
@Override
141147
protected void onSuccess(Channel value) throws Exception {
142-
if (asyncHandlerExtensions != null)
143-
asyncHandlerExtensions.onTlsHandshakeSuccess();
148+
if (asyncHandlerExtensions != null) {
149+
try {
150+
asyncHandlerExtensions.onTlsHandshakeSuccess();
151+
} catch (Exception e) {
152+
LOGGER.error("onTlsHandshakeSuccess crashed", e);
153+
NettyConnectListener.this.onFailure(channel, e);
154+
return;
155+
}
156+
}
144157
writeRequest(channel);
145158
}
146159

147160
@Override
148161
protected void onFailure(Throwable cause) throws Exception {
149-
if (asyncHandlerExtensions != null)
150-
asyncHandlerExtensions.onTlsHandshakeFailure(cause);
162+
if (asyncHandlerExtensions != null) {
163+
try {
164+
asyncHandlerExtensions.onTlsHandshakeFailure(cause);
165+
} catch (Exception e) {
166+
LOGGER.error("onTlsHandshakeFailure crashed", e);
167+
NettyConnectListener.this.onFailure(channel, e);
168+
return;
169+
}
170+
}
151171
NettyConnectListener.this.onFailure(channel, cause);
152172
}
153173
});

client/src/main/java/org/asynchttpclient/netty/request/NettyChannelConnector.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,15 @@ private boolean pickNextRemoteAddress() {
5959
public void connect(final Bootstrap bootstrap, final NettyConnectListener<?> connectListener) {
6060
final InetSocketAddress remoteAddress = remoteAddresses.get(i);
6161

62-
if (asyncHandlerExtensions != null)
63-
asyncHandlerExtensions.onTcpConnectAttempt(remoteAddress);
62+
if (asyncHandlerExtensions != null) {
63+
try {
64+
asyncHandlerExtensions.onTcpConnectAttempt(remoteAddress);
65+
} catch (Exception e) {
66+
LOGGER.error("onTcpConnectAttempt crashed", e);
67+
connectListener.onFailure(null, e);
68+
return;
69+
}
70+
}
6471

6572
try {
6673
connect0(bootstrap, connectListener, remoteAddress);
@@ -77,24 +84,37 @@ private void connect0(Bootstrap bootstrap, final NettyConnectListener<?> connect
7784

7885
bootstrap.connect(remoteAddress, localAddress)//
7986
.addListener(new SimpleChannelFutureListener() {
80-
8187
@Override
8288
public void onSuccess(Channel channel) {
8389
if (asyncHandlerExtensions != null) {
84-
asyncHandlerExtensions.onTcpConnectSuccess(remoteAddress, channel);
90+
try {
91+
asyncHandlerExtensions.onTcpConnectSuccess(remoteAddress, channel);
92+
} catch (Exception e) {
93+
LOGGER.error("onTcpConnectSuccess crashed", e);
94+
connectListener.onFailure(channel, e);
95+
return;
96+
}
8597
}
8698
connectListener.onSuccess(channel, remoteAddress);
8799
}
88100

89101
@Override
90102
public void onFailure(Channel channel, Throwable t) {
91-
if (asyncHandlerExtensions != null)
92-
asyncHandlerExtensions.onTcpConnectFailure(remoteAddress, t);
103+
if (asyncHandlerExtensions != null) {
104+
try {
105+
asyncHandlerExtensions.onTcpConnectFailure(remoteAddress, t);
106+
} catch (Exception e) {
107+
LOGGER.error("onTcpConnectFailure crashed", e);
108+
connectListener.onFailure(channel, e);
109+
return;
110+
}
111+
}
93112
boolean retry = pickNextRemoteAddress();
94-
if (retry)
113+
if (retry) {
95114
NettyChannelConnector.this.connect(bootstrap, connectListener);
96-
else
115+
} else {
97116
connectListener.onFailure(channel, t);
117+
}
98118
}
99119
});
100120
}

client/src/main/java/org/asynchttpclient/netty/request/NettyRequestSender.java

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package org.asynchttpclient.netty.request;
1515

1616
import static io.netty.handler.codec.http.HttpHeaderNames.EXPECT;
17+
import static org.asynchttpclient.handler.AsyncHandlerExtensionsUtils.toAsyncHandlerExtensions;
1718
import static org.asynchttpclient.util.Assertions.assertNotNull;
1819
import static org.asynchttpclient.util.AuthenticatorUtils.*;
1920
import static org.asynchttpclient.util.HttpConstants.Methods.*;
@@ -209,17 +210,25 @@ private <T> NettyResponseFuture<T> newNettyRequestAndResponseFuture(final Reques
209210
}
210211

211212
private Channel getOpenChannel(NettyResponseFuture<?> future, Request request, ProxyServer proxyServer, AsyncHandler<?> asyncHandler) {
212-
213-
if (future != null && future.isReuseChannel() && Channels.isChannelValid(future.channel()))
213+
if (future != null && future.isReuseChannel() && Channels.isChannelValid(future.channel())) {
214214
return future.channel();
215-
else
215+
} else {
216216
return pollPooledChannel(request, proxyServer, asyncHandler);
217+
}
217218
}
218219

219220
private <T> ListenableFuture<T> sendRequestWithOpenChannel(Request request, ProxyServer proxy, NettyResponseFuture<T> future, AsyncHandler<T> asyncHandler, Channel channel) {
220221

221-
if (asyncHandler instanceof AsyncHandlerExtensions)
222-
AsyncHandlerExtensions.class.cast(asyncHandler).onConnectionPooled(channel);
222+
final AsyncHandlerExtensions asyncHandlerExtensions = toAsyncHandlerExtensions(asyncHandler);
223+
if (asyncHandlerExtensions != null) {
224+
try {
225+
asyncHandlerExtensions.onConnectionPooled(channel);
226+
} catch (Exception e) {
227+
LOGGER.error("onConnectionPooled crashed", e);
228+
abort(channel, future, e);
229+
return future;
230+
}
231+
}
223232

224233
TimeoutsHolder timeoutsHolder = scheduleRequestTimeout(future);
225234
timeoutsHolder.initRemoteAddress((InetSocketAddress) channel.remoteAddress());
@@ -291,8 +300,7 @@ private <T> ListenableFuture<T> sendRequestWithNewChannel(//
291300

292301
@Override
293302
protected void onSuccess(List<InetSocketAddress> addresses) {
294-
NettyConnectListener<T> connectListener = new NettyConnectListener<>(
295-
future, NettyRequestSender.this, channelManager, connectionSemaphore, partitionKey);
303+
NettyConnectListener<T> connectListener = new NettyConnectListener<>(future, NettyRequestSender.this, channelManager, connectionSemaphore, partitionKey);
296304
NettyChannelConnector connector = new NettyChannelConnector(request.getLocalAddress(), addresses, asyncHandler, clientState, config);
297305
if (!future.isDone()) {
298306
connector.connect(bootstrap, connectListener);
@@ -338,14 +346,22 @@ public <T> void writeRequest(NettyResponseFuture<T> future, Channel channel) {
338346
return;
339347

340348
try {
341-
if (handler instanceof TransferCompletionHandler)
349+
if (handler instanceof TransferCompletionHandler) {
342350
configureTransferAdapter(handler, httpRequest);
351+
}
343352

344353
boolean writeBody = !future.isDontWriteBodyBecauseExpectContinue() && httpRequest.method() != HttpMethod.CONNECT && nettyRequest.getBody() != null;
345354

346355
if (!future.isHeadersAlreadyWrittenOnContinue()) {
347-
if (handler instanceof AsyncHandlerExtensions) {
348-
AsyncHandlerExtensions.class.cast(handler).onRequestSend(nettyRequest);
356+
final AsyncHandlerExtensions asyncHandlerExtensions = toAsyncHandlerExtensions(handler);
357+
if (asyncHandlerExtensions != null) {
358+
try {
359+
asyncHandlerExtensions.onRequestSend(nettyRequest);
360+
} catch (Exception e) {
361+
LOGGER.error("onRequestSend crashed", e);
362+
abort(channel, future, e);
363+
return;
364+
}
349365
}
350366

351367
// if the request has a body, we want to track progress
@@ -365,8 +381,9 @@ public <T> void writeRequest(NettyResponseFuture<T> future, Channel channel) {
365381
nettyRequest.getBody().write(channel, future);
366382

367383
// don't bother scheduling read timeout if channel became invalid
368-
if (Channels.isChannelValid(channel))
384+
if (Channels.isChannelValid(channel)) {
369385
scheduleReadTimeout(future);
386+
}
370387

371388
} catch (Exception e) {
372389
LOGGER.error("Can't write request", e);
@@ -398,8 +415,9 @@ private void scheduleReadTimeout(NettyResponseFuture<?> nettyResponseFuture) {
398415

399416
public void abort(Channel channel, NettyResponseFuture<?> future, Throwable t) {
400417

401-
if (channel != null)
418+
if (channel != null) {
402419
channelManager.closeChannel(channel);
420+
}
403421

404422
if (!future.isDone()) {
405423
future.setChannelState(ChannelState.CLOSED);
@@ -423,15 +441,23 @@ public void handleUnexpectedClosedChannel(Channel channel, NettyResponseFuture<?
423441

424442
public boolean retry(NettyResponseFuture<?> future) {
425443

426-
if (isClosed())
444+
if (isClosed()) {
427445
return false;
446+
}
428447

429448
if (future.isReplayPossible()) {
430449
future.setChannelState(ChannelState.RECONNECTED);
431450

432451
LOGGER.debug("Trying to recover request {}\n", future.getNettyRequest().getHttpRequest());
433-
if (future.getAsyncHandler() instanceof AsyncHandlerExtensions) {
434-
AsyncHandlerExtensions.class.cast(future.getAsyncHandler()).onRetry();
452+
final AsyncHandlerExtensions asyncHandlerExtensions = toAsyncHandlerExtensions(future.getAsyncHandler());
453+
if (asyncHandlerExtensions != null) {
454+
try {
455+
asyncHandlerExtensions.onRetry();
456+
} catch (Exception e) {
457+
LOGGER.error("onRetry crashed", e);
458+
abort(future.channel(), future, e);
459+
return false;
460+
}
435461
}
436462

437463
try {
@@ -478,19 +504,26 @@ private void validateWebSocketRequest(Request request, AsyncHandler<?> asyncHand
478504
Uri uri = request.getUri();
479505
boolean isWs = uri.isWebSocket();
480506
if (asyncHandler instanceof WebSocketUpgradeHandler) {
481-
if (!isWs)
507+
if (!isWs) {
482508
throw new IllegalArgumentException("WebSocketUpgradeHandler but scheme isn't ws or wss: " + uri.getScheme());
483-
else if (!request.getMethod().equals(GET) && !request.getMethod().equals(CONNECT))
509+
} else if (!request.getMethod().equals(GET) && !request.getMethod().equals(CONNECT)) {
484510
throw new IllegalArgumentException("WebSocketUpgradeHandler but method isn't GET or CONNECT: " + request.getMethod());
511+
}
485512
} else if (isWs) {
486513
throw new IllegalArgumentException("No WebSocketUpgradeHandler but scheme is " + uri.getScheme());
487514
}
488515
}
489516

490517
private Channel pollPooledChannel(Request request, ProxyServer proxy, AsyncHandler<?> asyncHandler) {
491518

492-
if (asyncHandler instanceof AsyncHandlerExtensions)
493-
AsyncHandlerExtensions.class.cast(asyncHandler).onConnectionPoolAttempt();
519+
final AsyncHandlerExtensions asyncHandlerExtensions = toAsyncHandlerExtensions(asyncHandler);
520+
if (asyncHandlerExtensions != null) {
521+
try {
522+
asyncHandlerExtensions.onConnectionPoolAttempt();
523+
} catch (Exception e) {
524+
LOGGER.error("onConnectionPoolAttempt crashed", e);
525+
}
526+
}
494527

495528
Uri uri = request.getUri();
496529
String virtualHost = request.getVirtualHost();
@@ -511,8 +544,16 @@ public void replayRequest(final NettyResponseFuture<?> future, FilterContext fc,
511544
future.touch();
512545

513546
LOGGER.debug("\n\nReplaying Request {}\n for Future {}\n", newRequest, future);
514-
if (future.getAsyncHandler() instanceof AsyncHandlerExtensions)
515-
AsyncHandlerExtensions.class.cast(future.getAsyncHandler()).onRetry();
547+
final AsyncHandlerExtensions asyncHandlerExtensions = toAsyncHandlerExtensions(future.getAsyncHandler());
548+
if (asyncHandlerExtensions != null) {
549+
try {
550+
asyncHandlerExtensions.onRetry();
551+
} catch (Exception e) {
552+
LOGGER.error("onRetry crashed", e);
553+
abort(channel, future, e);
554+
return;
555+
}
556+
}
516557

517558
channelManager.drainChannelAndOffer(channel, future);
518559
sendNextRequest(newRequest, future);

0 commit comments

Comments
 (0)