diff --git a/include/signalrclient/signalr_client_config.h b/include/signalrclient/signalr_client_config.h index b1e4fed0..37c4f91c 100644 --- a/include/signalrclient/signalr_client_config.h +++ b/include/signalrclient/signalr_client_config.h @@ -47,6 +47,10 @@ namespace signalr SIGNALRCLIENT_API const std::shared_ptr& __cdecl get_scheduler() const noexcept; SIGNALRCLIENT_API void set_handshake_timeout(std::chrono::milliseconds); SIGNALRCLIENT_API std::chrono::milliseconds get_handshake_timeout() const noexcept; + SIGNALRCLIENT_API void set_server_timeout(std::chrono::milliseconds); + SIGNALRCLIENT_API std::chrono::milliseconds get_server_timeout() const noexcept; + SIGNALRCLIENT_API void set_keepalive_interval(std::chrono::milliseconds); + SIGNALRCLIENT_API std::chrono::milliseconds get_keepalive_interval() const noexcept; private: #ifdef USE_CPPRESTSDK @@ -56,5 +60,7 @@ namespace signalr std::map m_http_headers; std::shared_ptr m_scheduler; std::chrono::milliseconds m_handshake_timeout; + std::chrono::milliseconds m_server_timeout; + std::chrono::milliseconds m_keepalive_interval; }; } diff --git a/src/signalrclient/hub_connection_impl.cpp b/src/signalrclient/hub_connection_impl.cpp index 8f705c4d..ea9074c6 100644 --- a/src/signalrclient/hub_connection_impl.cpp +++ b/src/signalrclient/hub_connection_impl.cpp @@ -185,6 +185,10 @@ namespace signalr callback(exception); }, exception); } + else + { + connection->start_keepalive(); + } }; auto handshake_request = handshake::write_handshake(connection->m_protocol); @@ -348,6 +352,7 @@ namespace signalr } } + reset_server_timeout(); auto messages = m_protocol->parse_messages(response); for (const auto& val : messages) @@ -385,7 +390,10 @@ namespace signalr // Sent to server only, should not be received by client throw std::runtime_error("Received unexpected message type 'CancelInvocation'."); case message_type::ping: - // TODO + if (m_logger.is_enabled(trace_level::debug)) + { + m_logger.log(trace_level::debug, "ping message received."); + } break; case message_type::close: // TODO @@ -477,6 +485,8 @@ namespace signalr } } }); + + reset_send_ping(); } catch (const std::exception& e) { @@ -510,6 +520,126 @@ namespace signalr m_disconnected = disconnected; } + void hub_connection_impl::reset_send_ping() + { + auto timeMs = (std::chrono::steady_clock::now() + m_signalr_client_config.get_keepalive_interval()).time_since_epoch(); + m_nextActivationSendPing.store(std::chrono::duration_cast(timeMs).count()); + } + + void hub_connection_impl::reset_server_timeout() + { + auto timeMs = (std::chrono::steady_clock::now() + m_signalr_client_config.get_server_timeout()).time_since_epoch(); + m_nextActivationServerTimeout.store(std::chrono::duration_cast(timeMs).count()); + } + + void hub_connection_impl::start_keepalive() + { + if (m_logger.is_enabled(trace_level::debug)) + { + m_logger.log(trace_level::debug, "starting keep alive timer."); + } + + auto send_ping = [](std::shared_ptr connection) + { + if (!connection) + { + return; + } + + if (connection->get_connection_state() != connection_state::connected) + { + return; + } + + try + { + hub_message ping_msg(signalr::message_type::ping); + auto message = connection->m_protocol->write_message(&ping_msg); + + std::weak_ptr weak_connection = connection; + connection->m_connection->send( + message, + connection->m_protocol->transfer_format(), [weak_connection](std::exception_ptr exception) + { + auto connection = weak_connection.lock(); + if (connection) + { + if (exception) + { + if (connection->m_logger.is_enabled(trace_level::warning)) + { + connection->m_logger.log(trace_level::warning, "failed to send ping!"); + } + } + else + { + connection->reset_send_ping(); + } + } + }); + } + catch (const std::exception& e) + { + if (connection->m_logger.is_enabled(trace_level::warning)) + { + connection->m_logger.log(trace_level::warning, std::string("failed to send ping: ").append(e.what())); + } + } + }; + + send_ping(shared_from_this()); + reset_server_timeout(); + + std::weak_ptr weak_connection = shared_from_this(); + timer(m_signalr_client_config.get_scheduler(), + [send_ping, weak_connection](std::chrono::milliseconds) + { + auto connection = weak_connection.lock(); + + if (!connection) + { + return true; + } + + if (connection->get_connection_state() != connection_state::connected) + { + return true; + } + + auto timeNowmSeconds = + std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count(); + + if (timeNowmSeconds > connection->m_nextActivationServerTimeout.load()) + { + if (connection->get_connection_state() == connection_state::connected) + { + auto error_msg = std::string("server timeout (") + .append(std::to_string(connection->m_signalr_client_config.get_server_timeout().count())) + .append(" ms) elapsed without receiving a message from the server."); + if (connection->m_logger.is_enabled(trace_level::warning)) + { + connection->m_logger.log(trace_level::warning, error_msg); + } + + connection->m_connection->stop([](std::exception_ptr) + { + }, std::make_exception_ptr(signalr_exception(error_msg))); + } + } + + if (timeNowmSeconds > connection->m_nextActivationSendPing.load()) + { + if (connection->m_logger.is_enabled(trace_level::debug)) + { + connection->m_logger.log(trace_level::debug, "sending ping to server."); + } + send_ping(connection); + } + + return false; + }); + } + // unnamed namespace makes it invisble outside this translation unit namespace { diff --git a/src/signalrclient/hub_connection_impl.h b/src/signalrclient/hub_connection_impl.h index a52302d7..578d0756 100644 --- a/src/signalrclient/hub_connection_impl.h +++ b/src/signalrclient/hub_connection_impl.h @@ -65,6 +65,9 @@ namespace signalr signalr_client_config m_signalr_client_config; std::unique_ptr m_protocol; + std::atomic m_nextActivationServerTimeout; + std::atomic m_nextActivationSendPing; + std::mutex m_stop_callback_lock; std::vector> m_stop_callbacks; @@ -75,5 +78,10 @@ namespace signalr void invoke_hub_method(const std::string& method_name, const std::vector& arguments, const std::string& callback_id, std::function set_completion, std::function set_exception) noexcept; bool invoke_callback(completion_message* completion); + + void reset_send_ping(); + void reset_server_timeout(); + + void start_keepalive(); }; } diff --git a/src/signalrclient/signalr_client_config.cpp b/src/signalrclient/signalr_client_config.cpp index 32962152..adbc3aec 100644 --- a/src/signalrclient/signalr_client_config.cpp +++ b/src/signalrclient/signalr_client_config.cpp @@ -44,6 +44,8 @@ namespace signalr signalr_client_config::signalr_client_config() : m_handshake_timeout(std::chrono::seconds(15)) + , m_server_timeout(std::chrono::seconds(30)) + , m_keepalive_interval(std::chrono::seconds(15)) { m_scheduler = std::make_shared(); } @@ -92,4 +94,34 @@ namespace signalr { return m_handshake_timeout; } + + void signalr_client_config::set_server_timeout(std::chrono::milliseconds timeout) + { + if (timeout <= std::chrono::seconds(0)) + { + throw std::runtime_error("timeout must be greater than 0."); + } + + m_server_timeout = timeout; + } + + std::chrono::milliseconds signalr_client_config::get_server_timeout() const noexcept + { + return m_server_timeout; + } + + void signalr_client_config::set_keepalive_interval(std::chrono::milliseconds interval) + { + if (interval <= std::chrono::seconds(0)) + { + throw std::runtime_error("interval must be greater than 0."); + } + + m_keepalive_interval = interval; + } + + std::chrono::milliseconds signalr_client_config::get_keepalive_interval() const noexcept + { + return m_keepalive_interval; + } } diff --git a/test/signalrclienttests/hub_connection_tests.cpp b/test/signalrclienttests/hub_connection_tests.cpp index 9e1c945a..d5ef1dc9 100644 --- a/test/signalrclienttests/hub_connection_tests.cpp +++ b/test/signalrclienttests/hub_connection_tests.cpp @@ -1738,9 +1738,9 @@ TEST(config, can_replace_scheduler) mre.get(); - // http_client->send (negotiate), websocket_client->start, handshake timeout timer, websocket_client->send, websocket_client->send, websocket_client->stop + // http_client->send (negotiate), websocket_client->start, handshake timeout timer, websocket_client->send, websocket_client->send, keep alive timer, websocket_client->send ping, websocket_client->stop // handshake timeout timer can trigger more than once if test takes more than 1 second - ASSERT_GE(6, scheduler->schedule_count); + ASSERT_GE(scheduler->schedule_count, 8); } class throw_hub_protocol : public hub_protocol @@ -1814,3 +1814,135 @@ TEST(send, throws_if_protocol_fails) ASSERT_EQ(connection_state::connected, hub_connection->get_connection_state()); } + +TEST(keepalive, sends_ping_messages) +{ + signalr_client_config config; + config.set_keepalive_interval(std::chrono::seconds(1)); + config.set_server_timeout(std::chrono::seconds(3)); + auto ping_mre = manual_reset_event(); + auto messages = std::make_shared>(); + auto websocket_client = create_test_websocket_client( + /* send function */ [messages, &ping_mre](const std::string& msg, std::function callback) + { + if (messages->size() < 3) + { + messages->push_back(msg); + } + if (messages->size() == 3) + { + ping_mre.set(); + } + callback(nullptr); + }, + [](const std::string&, std::function callback) { callback(nullptr); }, + [](std::function callback) { callback(nullptr); }, + false); + auto hub_connection = create_hub_connection(websocket_client); + hub_connection.set_client_config(config); + + auto mre = manual_reset_event(); + hub_connection.start([&mre](std::exception_ptr exception) + { + mre.set(exception); + }); + + ASSERT_FALSE(websocket_client->receive_loop_started.wait(5000)); + ASSERT_FALSE(websocket_client->handshake_sent.wait(5000)); + websocket_client->receive_message("{}\x1e"); + + mre.get(); + + ping_mre.get(); + + ASSERT_EQ(3, messages->size()); + ASSERT_EQ("{\"protocol\":\"json\",\"version\":1}\x1e", (*messages)[0]); + ASSERT_EQ("{\"type\":6}\x1e", (*messages)[1]); + ASSERT_EQ("{\"type\":6}\x1e", (*messages)[2]); + ASSERT_EQ(connection_state::connected, hub_connection.get_connection_state()); +} + +TEST(keepalive, server_timeout_on_no_ping_from_server) +{ + signalr_client_config config; + config.set_keepalive_interval(std::chrono::seconds(1)); + config.set_server_timeout(std::chrono::seconds(1)); + auto websocket_client = create_test_websocket_client(); + auto hub_connection = create_hub_connection(websocket_client); + hub_connection.set_client_config(config); + + auto disconnected_called = false; + + auto disconnect_mre = manual_reset_event(); + hub_connection.set_disconnected([&disconnected_called, &disconnect_mre](std::exception_ptr ex) + { + disconnect_mre.set(ex); + }); + + auto mre = manual_reset_event(); + hub_connection.start([&mre](std::exception_ptr exception) + { + mre.set(exception); + }); + + ASSERT_FALSE(websocket_client->receive_loop_started.wait(5000)); + ASSERT_FALSE(websocket_client->handshake_sent.wait(5000)); + websocket_client->receive_message("{}\x1e"); + + mre.get(); + + try + { + disconnect_mre.get(); + ASSERT_TRUE(false); + } + catch (const std::exception& ex) + { + ASSERT_STREQ("server timeout (1000 ms) elapsed without receiving a message from the server.", ex.what()); + } + ASSERT_EQ(connection_state::disconnected, hub_connection.get_connection_state()); +} + +TEST(keepalive, resets_server_timeout_timer_on_any_message_from_server) +{ + signalr_client_config config; + config.set_keepalive_interval(std::chrono::seconds(1)); + config.set_server_timeout(std::chrono::seconds(1)); + auto websocket_client = create_test_websocket_client(); + auto hub_connection = create_hub_connection(websocket_client); + hub_connection.set_client_config(config); + + auto disconnect_mre = manual_reset_event(); + hub_connection.set_disconnected([&disconnect_mre](std::exception_ptr ex) + { + disconnect_mre.set(ex); + }); + + auto mre = manual_reset_event(); + hub_connection.start([&mre](std::exception_ptr exception) + { + mre.set(exception); + }); + + ASSERT_FALSE(websocket_client->receive_loop_started.wait(5000)); + ASSERT_FALSE(websocket_client->handshake_sent.wait(5000)); + websocket_client->receive_message("{}\x1e"); + + mre.get(); + + std::this_thread::sleep_for(config.get_server_timeout() - std::chrono::milliseconds(500)); + websocket_client->receive_message("{\"type\":6}\x1e"); + std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_EQ(connection_state::connected, hub_connection.get_connection_state()); + + try + { + disconnect_mre.get(); + ASSERT_TRUE(false); + } + catch (const std::exception& ex) + { + ASSERT_STREQ("server timeout (1000 ms) elapsed without receiving a message from the server.", ex.what()); + } + ASSERT_EQ(connection_state::disconnected, hub_connection.get_connection_state()); +} diff --git a/test/signalrclienttests/test_websocket_client.cpp b/test/signalrclienttests/test_websocket_client.cpp index 8c99a897..7434f523 100644 --- a/test/signalrclienttests/test_websocket_client.cpp +++ b/test/signalrclienttests/test_websocket_client.cpp @@ -10,12 +10,14 @@ std::shared_ptr create_test_websocket_client( std::function)> send_function, std::function)> connect_function, - std::function)> close_function) + std::function)> close_function, + bool ignore_pings) { auto websocket_client = std::make_shared(); websocket_client->set_send_function(send_function); websocket_client->set_connect_function(connect_function); websocket_client->set_close_function(close_function); + websocket_client->ignore_pings = ignore_pings; return websocket_client; } @@ -24,7 +26,7 @@ test_websocket_client::test_websocket_client() : m_connect_function(std::make_shared)>>([](const std::string&, std::function callback) { callback(nullptr); })), m_send_function(std::make_shared)>>([](const std::string msg, std::function callback) { callback(nullptr); })), m_close_function(std::make_shared)>>([](std::function callback) { callback(nullptr); })), - m_receive_message_event(), m_receive_message(), m_stopped(true), receive_count(0) + m_receive_message_event(), m_receive_message(), m_stopped(true), receive_count(0), ignore_pings(true) { m_receive_loop_not_running.cancel(); } @@ -107,8 +109,14 @@ void test_websocket_client::send(const std::string& payload, signalr::transfer_f { handshake_sent.cancel(); auto local_copy = m_send_function; - m_scheduler->schedule([payload, callback, local_copy]() + auto l_ignore_pings = ignore_pings; + m_scheduler->schedule([payload, callback, local_copy, l_ignore_pings]() { + if (l_ignore_pings && payload.find("\"type\":6") != -1) + { + callback(nullptr); + return; + } (*local_copy)(payload, callback); }); } diff --git a/test/signalrclienttests/test_websocket_client.h b/test/signalrclienttests/test_websocket_client.h index d232d6da..f11c07a7 100644 --- a/test/signalrclienttests/test_websocket_client.h +++ b/test/signalrclienttests/test_websocket_client.h @@ -41,6 +41,7 @@ class test_websocket_client : public websocket_client cancellation_token_source receive_loop_started; cancellation_token_source handshake_sent; int receive_count; + bool ignore_pings; private: std::shared_ptr)>> m_connect_function; @@ -63,4 +64,4 @@ class test_websocket_client : public websocket_client std::shared_ptr create_test_websocket_client( std::function)> send_function = [](const std::string&, std::function callback) { callback(nullptr); }, std::function)> connect_function = [](const std::string&, std::function callback) { callback(nullptr); }, - std::function)> close_function = [](std::function callback) { callback(nullptr); }); \ No newline at end of file + std::function)> close_function = [](std::function callback) { callback(nullptr); }, bool ignore_pings = true); \ No newline at end of file