diff --git a/README.md b/README.md index 5f2543b..f0ac0da 100644 --- a/README.md +++ b/README.md @@ -79,9 +79,10 @@ The WebsocketServer can be initialized with the below parameters. | `send_message_to_all()` | Sends a `message` to **all** connected clients. The message is a simple string. | message | None | | `disconnect_clients_gracefully()` | Disconnect all connected clients by sending a websocket CLOSE handshake. | Optional: status, reason | None | | `disconnect_clients_abruptly()` | Disconnect all connected clients. Clients won't be aware until they try to send some data. | None | None | -| `shutdown_gracefully()` | Disconnect clients with a CLOSE handshake and shutdown server. | Optional: status, reason | None | -| `shutdown_abruptly()` | Disconnect clients and shutdown server with no handshake. | None | None | - +| `shutdown_gracefully()` | Disconnect clients with a CLOSE handshake and shutdown server. | Optional: status, reason | None | +| `shutdown_abruptly()` | Disconnect clients and shutdown server with no handshake. | None | None | +| `deny_new_connections()` | Close connection for new clients. | Optional: status, reason | None | +| `allow_new_connections()` | Allows back connection for new clients. | | None | ### Callback functions diff --git a/releases.txt b/releases.txt index 8b15c9b..2fda06c 100644 --- a/releases.txt +++ b/releases.txt @@ -27,3 +27,8 @@ 0.6.3 - Remove deprecation warnings + +0.6.4 +- Add deny_new_connections & allow_new_connections +- Fix disconnect_clients_gracefully to now take params +- Fix shutdown_gracefully unused param diff --git a/server.py b/server.py index f0587c6..210c884 100644 --- a/server.py +++ b/server.py @@ -19,7 +19,7 @@ def message_received(client, server, message): PORT=9001 -server = WebsocketServer(PORT) +server = WebsocketServer(port = PORT) server.set_fn_new_client(new_client) server.set_fn_client_left(client_left) server.set_fn_message_received(message_received) diff --git a/setup.py b/setup.py index baf1184..684b88d 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ from distutils.command.install import install -VERSION = '0.6.3' +VERSION = '0.6.4' def get_tag_version(): diff --git a/tests/test_server.py b/tests/test_server.py index 801e555..d03ce08 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -141,3 +141,19 @@ def test_disconnect_clients_gracefully(self, session): for i in range(3): client.send("test") sleep(0.2) + + def test_deny_new_connections(self, threaded_server): + url = "ws://{}:{}".format(*threaded_server.server_address) + server = threaded_server + server.deny_new_connections(status=1013, reason=b"Please try re-connecting later") + + conn = websocket.create_connection(url) + try: + conn.send("test") + except websocket.WebSocketProtocolException as e: + assert 'Invalid close opcode' in e.args[0] + assert not server.clients + + server.allow_new_connections() + conn = websocket.create_connection(url) + conn.send("test") diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 6d78613..c954c34 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -80,14 +80,20 @@ def send_message(self, client, msg): def send_message_to_all(self, msg): self._multicast(msg) + def deny_new_connections(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + self._deny_new_connections(status, reason) + + def allow_new_connections(self): + self._allow_new_connections() + def shutdown_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): - self._shutdown_gracefully(status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON) + self._shutdown_gracefully(status, reason) def shutdown_abruptly(self): self._shutdown_abruptly() - def disconnect_clients_gracefully(self): - self._disconnect_clients_gracefully() + def disconnect_clients_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + self._disconnect_clients_gracefully(status, reason) def disconnect_clients_abruptly(self): self._disconnect_clients_abruptly() @@ -131,6 +137,8 @@ def __init__(self, host='127.0.0.1', port=0, loglevel=logging.WARNING, key=None, self.id_counter = 0 self.thread = None + self._deny_clients = False + def _run_forever(self, threaded): cls_name = self.__class__.__name__ try: @@ -161,6 +169,13 @@ def _pong_received_(self, handler, msg): pass def _new_client_(self, handler): + if self._deny_clients: + status = self._deny_clients["status"] + reason = self._deny_clients["reason"] + handler.send_close(status, reason) + self._terminate_client_handler(handler) + return + self.id_counter += 1 client = { 'id': self.id_counter, @@ -188,14 +203,17 @@ def handler_to_client(self, handler): if client['handler'] == handler: return client + def _terminate_client_handler(self, handler): + handler.keep_alive = False + handler.finish() + handler.connection.close() + def _terminate_client_handlers(self): """ Ensures request handler for each client is terminated correctly """ for client in self.clients: - client["handler"].keep_alive = False - client["handler"].finish() - client["handler"].connection.close() + self._terminate_client_handler(client["handler"]) def _shutdown_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): """ @@ -220,7 +238,7 @@ def _disconnect_clients_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFA Terminate clients gracefully without shutting down the server """ for client in self.clients: - client["handler"].send_close(CLOSE_STATUS_NORMAL, reason) + client["handler"].send_close(status, reason) self._terminate_client_handlers() def _disconnect_clients_abruptly(self): @@ -229,6 +247,15 @@ def _disconnect_clients_abruptly(self): """ self._terminate_client_handlers() + def _deny_new_connections(self, status, reason): + self._deny_clients = { + "status": status, + "reason": reason, + } + + def _allow_new_connections(self): + self._deny_clients = False + class WebSocketHandler(StreamRequestHandler): @@ -238,9 +265,13 @@ def __init__(self, socket, addr, server): self._send_lock = threading.Lock() if server.key and server.cert: try: - socket = ssl.wrap_socket(socket, server_side=True, certfile=server.cert, keyfile=server.key) - except: # Not sure which exception it throws if the key/cert isn't found - logger.warning("SSL not available (are the paths {} and {} correct for the key and cert?)".format(server.key, server.cert)) + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(certfile=server.cert, keyfile=server.key) + socket = ssl_context.wrap_socket(socket, server_side=True) + except FileNotFoundError: + logger.warning("SSL key or certificate file not found. Please check the paths for the key and cert.") + except ssl.SSLError as e: + logger.warning(f"SSL error occurred: {e}") StreamRequestHandler.__init__(self, socket, addr, server) def setup(self):