diff --git a/.circleci/config.yml b/.circleci/config.yml index 4eddce6..68aeaf9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,22 +1,73 @@ -# Python CircleCI 2.0 configuration file -# -# Check https://circleci.com/docs/2.0/language-python/ for more details -# -version: 2 +version: 2.1 + +orbs: + python: circleci/python@0.2.1 + jobs: - build: + test: docker: - - image: circleci/python:3 + - image: circleci/python:3.7 steps: - checkout - run: name: install dependencies command: | - sudo pip install tox + pip install -r requirements.txt - run: name: run tests command: | - tox + pytest - store_artifacts: path: test-reports destination: test-reports + deploy: + executor: python/default + steps: + - checkout + - python/load-cache + - run: + name: install twine + command: | + sudo pip install twine + - run: + name: verify git tag vs. version + command: | + python setup.py verify + - run: + name: create packages + command: | + python setup.py sdist + python setup.py bdist_wheel + - run: + name: setup pypi credentials + command: | + echo -e "[pypi]" >> ~/.pypirc + echo -e "username = $PYPI_USERNAME" >> ~/.pypirc + echo -e "password = $PYPI_PASSWORD" >> ~/.pypirc + - run: + name: upload to cheeseshop if a tag found + command: | + tag=`git tag --points-at HEAD` + if [ $tag ]; then + echo Uploading + python -m twine upload dist/* + fi + +workflows: + test: + jobs: + - test: + filters: + branches: + only: development + new_release: + jobs: + - test: + filters: + branches: + only: master + tags: + only: /v[0-9]+(\.[0-9]+)*/ + - deploy: + requires: + - test diff --git a/README.md b/README.md index 0ed6676..f0ac0da 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,24 @@ Websocket Server ======================= -[![CircleCI](https://circleci.com/gh/Pithikos/python-websocket-server/tree/master.svg?style=svg)](https://circleci.com/gh/Pithikos/python-websocket-server/tree/master) +[![CircleCI](https://circleci.com/gh/Pithikos/python-websocket-server/tree/master.svg?style=svg)](https://circleci.com/gh/Pithikos/python-websocket-server/tree/master) [![PyPI version](https://badge.fury.io/py/websocket-server.svg)](https://badge.fury.io/py/websocket-server) A minimal Websockets Server in Python with no external dependencies. - * Python2 and Python3 support + * Python3.6+ * Clean simple API * Multiple clients * No dependencies -Notice that this implementation does not support the more advanced features -like SSL etc. The project is focused mainly on making it easy to run a -websocket server for prototyping, testing or for making a GUI for your application. +Notice this project is focused mainly on making it easy to run a websocket server for prototyping, testing or for making a GUI for your application. Thus not all possible features of Websockets are supported. Installation ======================= -You can use the project in three ways. +Install with pip - 1. Copy/paste the *websocket_server.py* file in your project and use it directly - 2. `pip install git+https://github.com/Pithikos/python-websocket-server` (latest code) - 3. `pip install websocket-server` (might not be up-to-date) + pip install websocket-server For coding details have a look at the [*server.py*](https://github.com/Pithikos/python-websocket-server/blob/master/server.py) example and the [API](https://github.com/Pithikos/python-websocket-server#api). @@ -41,7 +37,7 @@ Testing Run all tests - tox + pytest API @@ -59,6 +55,10 @@ The WebsocketServer can be initialized with the below parameters. *`loglevel`* - logging level to print. By default WARNING is used. You can use `logging.DEBUG` or `logging.INFO` for more verbose output. +*`key`* - If using SSL, this is the path to the key. + +*`cert`* - If using SSL, this is the path to the certificate. + ### Properties @@ -71,11 +71,18 @@ The WebsocketServer can be initialized with the below parameters. | Method | Description | Takes | Gives | |-----------------------------|---------------------------------------------------------------------------------------|-----------------|-------| +| `run_forever()` | Runs server until shutdown_gracefully or shutdown_abruptly are called. | threaded: run server on its own thread if True | None | | `set_fn_new_client()` | Sets a callback function that will be called for every new `client` connecting to us | function | None | | `set_fn_client_left()` | Sets a callback function that will be called for every `client` disconnecting from us | function | None | | `set_fn_message_received()` | Sets a callback function that will be called when a `client` sends a message | function | None | | `send_message()` | Sends a `message` to a specific `client`. The message is a simple string. | client, message | None | | `send_message_to_all()` | Sends a `message` to **all** connected clients. The message is a simple string. | message | None | +| `disconnect_clients_gracefully()` | Disconnect all connected clients by sending a websocket CLOSE handshake. | Optional: status, reason | None | +| `disconnect_clients_abruptly()` | Disconnect all connected clients. Clients won't be aware until they try to send some data. | None | None | +| `shutdown_gracefully()` | Disconnect clients with a CLOSE handshake and shutdown server. | Optional: status, reason | None | +| `shutdown_abruptly()` | Disconnect clients and shutdown server with no handshake. | None | None | +| `deny_new_connections()` | Close connection for new clients. | Optional: status, reason | None | +| `allow_new_connections()` | Allows back connection for new clients. | | None | ### Callback functions @@ -98,10 +105,22 @@ from websocket_server import WebsocketServer def new_client(client, server): server.send_message_to_all("Hey all, a new client has joined us") -server = WebsocketServer(13254, host='127.0.0.1', loglevel=logging.INFO) +server = WebsocketServer(host='127.0.0.1', port=13254, loglevel=logging.INFO) +server.set_fn_new_client(new_client) +server.run_forever() +```` +Example (SSL): +````py +import logging +from websocket_server import WebsocketServer + +def new_client(client, server): + server.send_message_to_all("Hey all, a new client has joined us") + +server = WebsocketServer(host='127.0.0.1', port=13254, loglevel=logging.INFO, key="key.pem", cert="cert.pem") server.set_fn_new_client(new_client) server.run_forever() -```` +```` ## Client diff --git a/docs/release-workflow.md b/docs/release-workflow.md new file mode 100644 index 0000000..316e5d8 --- /dev/null +++ b/docs/release-workflow.md @@ -0,0 +1,16 @@ +Release notes +------------- + +Releases are marked on master branch with tags. The upload to pypi is automated as long as a merge +from development comes with a tag. + +General flow + + 1. Get in dev branch + 2. Update VERSION in setup.py and releases.txt file + 3. Make a commit + 4. Merge development into master (`git merge --no-ff development`) + 4. Add corresponding version as a new tag (`git tag `) e.g. git tag v0.3.0 + 5. Push everything (`git push --tags && git push`) + +- diff --git a/releases.txt b/releases.txt new file mode 100644 index 0000000..2fda06c --- /dev/null +++ b/releases.txt @@ -0,0 +1,34 @@ +0.4 +- Python 2 and 3 support + +0.5.1 +- SSL support +- Drop Python 2 support + +0.5.4 +- Add API for shutting down server (abruptly & gracefully) + +0.5.5 +- Allow running run_forever threaded +- Fix shutting down of a server without connected clients + +0.5.6 +- Support from Python3.6+ + +0.6.0 +- Change order of params 'host' and 'port' +- Add host attribute to server + +0.6.1 +- Sending data is now thread-safe + +0.6.2 +- Add API for disconnecting clients + +0.6.3 +- Remove deprecation warnings + +0.6.4 +- Add deny_new_connections & allow_new_connections +- Fix disconnect_clients_gracefully to now take params +- Fix shutdown_gracefully unused param diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5075893 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +# Dev/test/deploy +IPython +pytest +websocket-client>=1.1.1 +twine diff --git a/server.py b/server.py index f0587c6..210c884 100644 --- a/server.py +++ b/server.py @@ -19,7 +19,7 @@ def message_received(client, server, message): PORT=9001 -server = WebsocketServer(PORT) +server = WebsocketServer(port = PORT) server.set_fn_new_client(new_client) server.set_fn_client_left(client_left) server.set_fn_message_received(message_received) diff --git a/setup.py b/setup.py index b495e1b..684b88d 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,46 @@ -from setuptools import setup, find_packages +import os +import sys +import re +import subprocess +import shlex + +try: + from setuptools import setup, find_packages + from setuptools.command.install import install +except ImportError: + from distutils.core import setup, find_packages + from distutils.command.install import install + + +VERSION = '0.6.4' + + +def get_tag_version(): + cmd = 'git tag --points-at HEAD' + versions = subprocess.check_output(shlex.split(cmd)).splitlines() + if not versions: + return None + if len(versions) != 1: + sys.exit(f"Trying to get tag via git: Expected excactly one tag, got {len(versions)}") + version = versions[0].decode() + if re.match('^v[0-9]', version): + version = version[1:] + return version + + +class VerifyVersionCommand(install): + """ Custom command to verify that the git tag matches our version """ + description = 'verify that the git tag matches our version' + + def run(self): + tag_version = get_tag_version() + if tag_version and tag_version != VERSION: + sys.exit(f"Git tag: {tag} does not match the version of this app: {VERSION}") + setup( name='websocket_server', - version='0.4', + version=VERSION, packages=find_packages("."), url='https://github.com/Pithikos/python-websocket-server', license='MIT', @@ -12,4 +50,8 @@ ], description='A simple fully working websocket-server in Python with no external dependencies', platforms='any', + cmdclass={ + 'verify': VerifyVersionCommand, + }, + python_requires=">=3.6", ) diff --git a/tests/_bootstrap_.py b/tests/_bootstrap_.py deleted file mode 100644 index b8a7949..0000000 --- a/tests/_bootstrap_.py +++ /dev/null @@ -1,6 +0,0 @@ -# Add path to source code -import sys, os -if os.getcwd().endswith('tests'): - sys.path.insert(0, '..') -elif os.getcwd().endswith('websocket-server'): - sys.path.insert(0, '.') diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d8ddf26 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,92 @@ +import logging +from time import sleep +from threading import Thread + +import pytest +import websocket # websocket-client + +# Add path to source code +import sys, os +if os.getcwd().endswith('tests'): + sys.path.insert(0, '..') +elif os.path.exists('websocket_server'): + sys.path.insert(0, '.') +from websocket_server import WebsocketServer + + +class TestClient(): + def __init__(self, port, threaded=True): + self.received_messages = [] + self.closes = [] + self.opens = [] + self.errors = [] + + websocket.enableTrace(True) + self.ws = websocket.WebSocketApp(f"ws://localhost:{port}/", + on_open=self.on_open, + on_message=self.on_message, + on_error=self.on_error, + on_close=self.on_close) + if threaded: + self.thread = Thread(target=self.ws.run_forever) + self.thread.daemon = True + self.thread.start() + else: + self.ws.run_forever() + + def on_message(self, ws, message): + self.received_messages.append(message) + print(f"TestClient: on_message: {message}") + + def on_error(self, ws, error): + self.errors.append(error) + print(f"TestClient: on_error: {error}") + + def on_close(self, ws, close_status_code, close_msg): + self.closes.append((close_status_code, close_msg)) + print(f"TestClient: on_close: {close_status_code} - {close_msg}") + + def on_open(self, ws): + self.opens.append(ws) + print("TestClient: on_open") + + +class TestServer(WebsocketServer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.received_messages = [] + self.set_fn_message_received(self.handle_received_message) + + def handle_received_message(self, client, server, message): + self.received_messages.append(message) + + +@pytest.fixture(scope='function') +def threaded_server(): + """ Returns the response of a server after""" + server = TestServer(loglevel=logging.DEBUG) + server.run_forever(threaded=True) + yield server + server.server_close() + + +@pytest.fixture +def session(threaded_server): + """ + Gives a simple connection to a server + """ + conn = websocket.create_connection("ws://{}:{}".format(*threaded_server.server_address)) + yield conn, threaded_server + conn.close() + + +@pytest.fixture +def client_session(threaded_server): + """ + Gives a TestClient instance connected to a server + """ + client = TestClient(port=threaded_server.port) + sleep(1) + assert client.ws.sock and client.ws.sock.connected + yield client, threaded_server + client.ws.close() diff --git a/tests/test_handshake.py b/tests/test_handshake.py index dfbcc36..74ace26 100644 --- a/tests/test_handshake.py +++ b/tests/test_handshake.py @@ -1,4 +1,3 @@ -import _bootstrap_ from websocket_server import WebSocketHandler diff --git a/tests/test_message_lengths.py b/tests/test_message_lengths.py new file mode 100644 index 0000000..03f9d96 --- /dev/null +++ b/tests/test_message_lengths.py @@ -0,0 +1,67 @@ +def test_text_message_of_length_1(session): + conn, server = session + server.send_message_to_all('$') + assert conn.recv() == '$' + + +def test_text_message_of_length_125B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqr125' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_126B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrs126' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_127B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrst127' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_208B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvw208' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_1251B(session): + conn, server = session + msg = ('abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqr125'*10)+'1' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_68KB(session): + conn, server = session + msg = '$'+('a'*67993)+'68000'+'^' + assert len(msg) == 68000 + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_1500KB(session): + """ An enormous message (well beyond 65K) """ + conn, server = session + msg = '$'+('a'*1499991)+'1500000'+'^' + assert len(msg) == 1500000 + server.send_message_to_all(msg) + assert conn.recv() == msg diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..d03ce08 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,159 @@ +from time import sleep +import threading + +import websocket +import pytest + + +class TestServerThreadedWithoutClient(): + def test_run_forever(self, threaded_server): + assert threaded_server.thread + assert not isinstance(threaded_server.thread, threading._MainThread) + assert threaded_server.thread.is_alive() + + def test_attributes(self, threaded_server): + tpl = threaded_server.server_address + assert threaded_server.port == tpl[1] + assert threaded_server.host == tpl[0] + + def test_shutdown(self, threaded_server): + assert threaded_server.thread.is_alive() + + # Shutdown de-facto way + # REF: https://docs.python.org/3/library/socketserver.html + # "Tell the serve_forever() loop to stop and + # wait until it does. shutdown() must be called while serve_forever() + # is running in a different thread otherwise it will deadlock." + threaded_server.shutdown() + assert not threaded_server.thread.is_alive() + + def test_shutdown_gracefully_without_clients(self, threaded_server): + assert threaded_server.thread.is_alive() + threaded_server.shutdown_gracefully() + assert not threaded_server.thread.is_alive() + assert threaded_server.socket.fileno() <= 0 + + def test_shutdown_abruptly_without_clients(self, threaded_server): + assert threaded_server.thread.is_alive() + threaded_server.shutdown_abruptly() + assert not threaded_server.thread.is_alive() + assert threaded_server.socket.fileno() <= 0 + + +class TestServerThreadedWithClient(): + def test_send_close(self, client_session): + """ + Ensure client stops receiving data once we send_close (socket is still open) + """ + client, server = client_session + assert client.received_messages == [] + + server.send_message_to_all("test1") + sleep(0.5) + assert client.received_messages == ["test1"] + + # After CLOSE, client should not be receiving any messages + server.clients[-1]["handler"].send_close() + sleep(0.5) + server.send_message_to_all("test2") + sleep(0.5) + assert client.received_messages == ["test1"] + + def test_shutdown_gracefully(self, client_session): + client, server = client_session + assert client.ws.sock and client.ws.sock.connected + assert server.socket.fileno() > 0 + + server.shutdown_gracefully() + sleep(0.5) + + # Ensure all parties disconnected + assert not client.ws.sock + assert server.socket.fileno() == -1 + assert not server.clients + + def test_shutdown_abruptly(self, client_session): + client, server = client_session + assert client.ws.sock and client.ws.sock.connected + assert server.socket.fileno() > 0 + + server.shutdown_abruptly() + sleep(0.5) + + # Ensure server socket died + assert server.socket.fileno() == -1 + + # Ensure client handler terminated + assert server.received_messages == [] + assert client.errors == [] + client.ws.send("1st msg after server shutdown") + sleep(0.5) + + # Note the message is received since the client handler + # will terminate only once it has received the last message + # and break out of the keep_alive loop. Any consecutive messages + # will not be received though. + assert server.received_messages == ["1st msg after server shutdown"] + assert len(client.errors) == 1 + assert isinstance(client.errors[0], websocket._exceptions.WebSocketConnectionClosedException) + + # Try to send 2nd message + with pytest.raises(websocket._exceptions.WebSocketConnectionClosedException): + client.ws.send("2nd msg after server shutdown") + + def test_client_closes_gracefully(self, session): + client, server = session + assert client.connected + assert server.clients + old_client_handler = server.clients[0]["handler"] + client.close() + assert not client.connected + + # Ensure server closed connection. + # We test this by having the server trying to send + # data to the client + assert not server.clients + with pytest.raises(BrokenPipeError): + old_client_handler.connection.send(b"test") + + def test_disconnect_clients_abruptly(self, session): + client, server = session + assert client.connected + assert server.clients + server.disconnect_clients_abruptly() + assert not server.clients + + # Client won't be aware until trying to write more data + with pytest.raises(BrokenPipeError): + for i in range(3): + client.send("test") + sleep(0.2) + + def test_disconnect_clients_gracefully(self, session): + client, server = session + assert client.connected + assert server.clients + server.disconnect_clients_gracefully() + assert not server.clients + + # Client won't be aware until trying to write more data + with pytest.raises(BrokenPipeError): + for i in range(3): + client.send("test") + sleep(0.2) + + def test_deny_new_connections(self, threaded_server): + url = "ws://{}:{}".format(*threaded_server.server_address) + server = threaded_server + server.deny_new_connections(status=1013, reason=b"Please try re-connecting later") + + conn = websocket.create_connection(url) + try: + conn.send("test") + except websocket.WebSocketProtocolException as e: + assert 'Invalid close opcode' in e.args[0] + assert not server.clients + + server.allow_new_connections() + conn = websocket.create_connection(url) + conn.send("test") diff --git a/tests/test_text_messages.py b/tests/test_text_messages.py index 6090f4c..b4f2d0f 100644 --- a/tests/test_text_messages.py +++ b/tests/test_text_messages.py @@ -1,90 +1,86 @@ -# -*- coding: utf-8 -*- -from utils import session, server - - def test_text_message_of_length_1(session): - client, server = session + conn, server = session server.send_message_to_all('$') - assert client.recv() == '$' + assert conn.recv() == '$' def test_text_message_of_length_125B(session): - client, server = session + conn, server = session msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqr125' server.send_message_to_all(msg) - assert client.recv() == msg + assert conn.recv() == msg def test_text_message_of_length_126B(session): - client, server = session + conn, server = session msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrs126' server.send_message_to_all(msg) - assert client.recv() == msg + assert conn.recv() == msg def test_text_message_of_length_127B(session): - client, server = session + conn, server = session msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrst127' server.send_message_to_all(msg) - assert client.recv() == msg + assert conn.recv() == msg def test_text_message_of_length_208B(session): - client, server = session + conn, server = session msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvw208' server.send_message_to_all(msg) - assert client.recv() == msg + assert conn.recv() == msg def test_text_message_of_length_1251B(session): - client, server = session + conn, server = session msg = ('abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ 'abcdefghijklmnopqr125'*10)+'1' server.send_message_to_all(msg) - assert client.recv() == msg + assert conn.recv() == msg def test_text_message_of_length_68KB(session): - client, server = session + conn, server = session msg = '$'+('a'*67993)+'68000'+'^' assert len(msg) == 68000 server.send_message_to_all(msg) - assert client.recv() == msg + assert conn.recv() == msg def test_text_message_of_length_1500KB(session): """ An enormous message (well beyond 65K) """ - client, server = session + conn, server = session msg = '$'+('a'*1499991)+'1500000'+'^' assert len(msg) == 1500000 server.send_message_to_all(msg) - assert client.recv() == msg + assert conn.recv() == msg def test_text_message_with_unicode_characters(session): - client, server = session + conn, server = session msg = '$äüö^' server.send_message_to_all(msg) - assert client.recv() == msg + assert conn.recv() == msg def test_text_message_stress_bursts(session): - """ Scenario: server sends multiple different message to the same client + """ Scenario: server sends multiple different message to the same conn at once """ from threading import Thread NUM_THREADS = 100 MESSAGE_LEN = 1000 - client, server = session + conn, server = session messages_received = [] # Threads receing @@ -92,7 +88,7 @@ def test_text_message_stress_bursts(session): for i in range(NUM_THREADS): th = Thread( target=lambda fn: messages_received.append(fn()), - args=(client.recv,) + args=(conn.recv,) ) th.daemon = True threads_receiving.append(th) diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 0bd9933..0000000 --- a/tests/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -import logging -from threading import Thread - -import pytest -from websocket import create_connection # websocket-client - -import _bootstrap_ -from websocket_server import WebsocketServer - - -@pytest.fixture(scope='function') -def server(): - """ Returns the response of a server after""" - s = WebsocketServer(0, loglevel=logging.DEBUG) - server_thread = Thread(target=s.run_forever) - server_thread.daemon = True - server_thread.start() - yield s - s.server_close() - - -@pytest.fixture -def session(server): - ws = create_connection("ws://{}:{}".format(*server.server_address)) - yield ws, server - ws.close() diff --git a/tox.ini b/tox.ini deleted file mode 100644 index b3025e9..0000000 --- a/tox.ini +++ /dev/null @@ -1,6 +0,0 @@ -[tox] -envlist = py27,py3 -[testenv] -deps=pytest - websocket-client -commands=pytest diff --git a/websocket_server/thread.py b/websocket_server/thread.py new file mode 100644 index 0000000..a474203 --- /dev/null +++ b/websocket_server/thread.py @@ -0,0 +1,38 @@ +import threading + + +class ThreadWithLoggedException(threading.Thread): + """ + Similar to Thread but will log exceptions to passed logger. + + Args: + logger: Logger instance used to log any exception in child thread + + Exception is also reachable via .exception from the main thread. + """ + + DIVIDER = "*"*80 + + def __init__(self, *args, **kwargs): + try: + self.logger = kwargs.pop("logger") + except KeyError: + raise Exception("Missing 'logger' in kwargs") + super().__init__(*args, **kwargs) + self.exception = None + + def run(self): + try: + if self._target is not None: + self._target(*self._args, **self._kwargs) + except Exception as exception: + thread = threading.current_thread() + self.exception = exception + self.logger.exception(f"{self.DIVIDER}\nException in child thread {thread}: {exception}\n{self.DIVIDER}") + finally: + del self._target, self._args, self._kwargs + + +class WebsocketServerThread(ThreadWithLoggedException): + """Dummy wrapper to make debug messages a bit more readable""" + pass diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 96c658b..c954c34 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -3,16 +3,16 @@ import sys import struct +import ssl from base64 import b64encode from hashlib import sha1 import logging from socket import error as SocketError import errno +import threading +from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler -if sys.version_info[0] < 3: - from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler -else: - from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler +from websocket_server.thread import WebsocketServerThread logger = logging.getLogger(__name__) logging.basicConfig() @@ -47,21 +47,14 @@ OPCODE_PING = 0x9 OPCODE_PONG = 0xA +CLOSE_STATUS_NORMAL = 1000 +DEFAULT_CLOSE_REASON = bytes('', encoding='utf-8') -# -------------------------------- API --------------------------------- class API(): - def run_forever(self): - try: - logger.info("Listening on port %d for clients.." % self.port) - self.serve_forever() - except KeyboardInterrupt: - self.server_close() - logger.info("Server terminated.") - except Exception as e: - logger.error(str(e), exc_info=True) - exit(1) + def run_forever(self, threaded=False): + return self._run_forever(threaded) def new_client(self, client, server): pass @@ -82,13 +75,29 @@ def set_fn_message_received(self, fn): self.message_received = fn def send_message(self, client, msg): - self._unicast_(client, msg) + self._unicast(client, msg) def send_message_to_all(self, msg): - self._multicast_(msg) + self._multicast(msg) + def deny_new_connections(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + self._deny_new_connections(status, reason) + + def allow_new_connections(self): + self._allow_new_connections() + + def shutdown_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + self._shutdown_gracefully(status, reason) + + def shutdown_abruptly(self): + self._shutdown_abruptly() + + def disconnect_clients_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + self._disconnect_clients_gracefully(status, reason) + + def disconnect_clients_abruptly(self): + self._disconnect_clients_abruptly() -# ------------------------- Implementation ----------------------------- class WebsocketServer(ThreadingMixIn, TCPServer, API): """ @@ -115,14 +124,41 @@ class WebsocketServer(ThreadingMixIn, TCPServer, API): allow_reuse_address = True daemon_threads = True # comment to keep threads alive until finished - clients = [] - id_counter = 0 - - def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING): + def __init__(self, host='127.0.0.1', port=0, loglevel=logging.WARNING, key=None, cert=None): logger.setLevel(loglevel) TCPServer.__init__(self, (host, port), WebSocketHandler) + self.host = host self.port = self.socket.getsockname()[1] + self.key = key + self.cert = cert + + self.clients = [] + self.id_counter = 0 + self.thread = None + + self._deny_clients = False + + def _run_forever(self, threaded): + cls_name = self.__class__.__name__ + try: + logger.info("Listening on port %d for clients.." % self.port) + if threaded: + self.daemon = True + self.thread = WebsocketServerThread(target=super().serve_forever, daemon=True, logger=logger) + logger.info(f"Starting {cls_name} on thread {self.thread.getName()}.") + self.thread.start() + else: + self.thread = threading.current_thread() + logger.info(f"Starting {cls_name} on main thread.") + super().serve_forever() + except KeyboardInterrupt: + self.server_close() + logger.info("Server terminated.") + except Exception as e: + logger.error(str(e), exc_info=True) + sys.exit(1) + def _message_received_(self, handler, msg): self.message_received(self.handler_to_client(handler), self, msg) @@ -133,6 +169,13 @@ def _pong_received_(self, handler, msg): pass def _new_client_(self, handler): + if self._deny_clients: + status = self._deny_clients["status"] + reason = self._deny_clients["reason"] + handler.send_close(status, reason) + self._terminate_client_handler(handler) + return + self.id_counter += 1 client = { 'id': self.id_counter, @@ -148,23 +191,87 @@ def _client_left_(self, handler): if client in self.clients: self.clients.remove(client) - def _unicast_(self, to_client, msg): - to_client['handler'].send_message(msg) + def _unicast(self, receiver_client, msg): + receiver_client['handler'].send_message(msg) - def _multicast_(self, msg): + def _multicast(self, msg): for client in self.clients: - self._unicast_(client, msg) + self._unicast(client, msg) def handler_to_client(self, handler): for client in self.clients: if client['handler'] == handler: return client + def _terminate_client_handler(self, handler): + handler.keep_alive = False + handler.finish() + handler.connection.close() + + def _terminate_client_handlers(self): + """ + Ensures request handler for each client is terminated correctly + """ + for client in self.clients: + self._terminate_client_handler(client["handler"]) + + def _shutdown_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + """ + Send a CLOSE handshake to all connected clients before terminating server + """ + self.keep_alive = False + self._disconnect_clients_gracefully(status, reason) + self.server_close() + self.shutdown() + + def _shutdown_abruptly(self): + """ + Terminate server without sending a CLOSE handshake + """ + self.keep_alive = False + self._disconnect_clients_abruptly() + self.server_close() + self.shutdown() + + def _disconnect_clients_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + """ + Terminate clients gracefully without shutting down the server + """ + for client in self.clients: + client["handler"].send_close(status, reason) + self._terminate_client_handlers() + + def _disconnect_clients_abruptly(self): + """ + Terminate clients abruptly (no CLOSE handshake) without shutting down the server + """ + self._terminate_client_handlers() + + def _deny_new_connections(self, status, reason): + self._deny_clients = { + "status": status, + "reason": reason, + } + + def _allow_new_connections(self): + self._deny_clients = False + class WebSocketHandler(StreamRequestHandler): def __init__(self, socket, addr, server): self.server = server + assert not hasattr(self, "_send_lock"), "_send_lock already exists" + self._send_lock = threading.Lock() + if server.key and server.cert: + try: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(certfile=server.cert, keyfile=server.key) + socket = ssl_context.wrap_socket(socket, server_side=True) + except FileNotFoundError: + logger.warning("SSL key or certificate file not found. Please check the paths for the key and cert.") + except ssl.SSLError as e: + logger.warning(f"SSL error occurred: {e}") StreamRequestHandler.__init__(self, socket, addr, server) def setup(self): @@ -181,12 +288,7 @@ def handle(self): self.read_next_message() def read_bytes(self, num): - # python3 gives ordinal of byte directly - bytes = self.rfile.read(num) - if sys.version_info[0] < 3: - return map(ord, bytes) - else: - return bytes + return self.rfile.read(num) def read_next_message(self): try: @@ -210,14 +312,14 @@ def read_next_message(self): self.keep_alive = 0 return if not masked: - logger.warn("Client must always be masked.") + logger.warning("Client must always be masked.") self.keep_alive = 0 return if opcode == OPCODE_CONTINUATION: - logger.warn("Continuation frames are not supported.") + logger.warning("Continuation frames are not supported.") return elif opcode == OPCODE_BINARY: - logger.warn("Binary frames are not supported.") + logger.warning("Binary frames are not supported.") return elif opcode == OPCODE_TEXT: opcode_handler = self.server._message_received_ @@ -226,7 +328,7 @@ def read_next_message(self): elif opcode == OPCODE_PONG: opcode_handler = self.server._pong_received_ else: - logger.warn("Unknown opcode %#x." % opcode) + logger.warning("Unknown opcode %#x." % opcode) self.keep_alive = 0 return @@ -248,6 +350,28 @@ def send_message(self, message): def send_pong(self, message): self.send_text(message, OPCODE_PONG) + def send_close(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + """ + Send CLOSE to client + + Args: + status: Status as defined in https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 + reason: Text with reason of closing the connection + """ + if status < CLOSE_STATUS_NORMAL or status > 1015: + raise Exception(f"CLOSE status must be between 1000 and 1015, got {status}") + + header = bytearray() + payload = struct.pack('!H', status) + reason + payload_length = len(payload) + assert payload_length <= 125, "We only support short closing reasons at the moment" + + # Send CLOSE with status & reason + header.append(FIN | OPCODE_CLOSE_CONN) + header.append(payload_length) + with self._send_lock: + self.request.send(header + payload) + def send_text(self, message, opcode=OPCODE_TEXT): """ Important: Fragmented(=continuation) messages are not supported since @@ -260,12 +384,8 @@ def send_text(self, message, opcode=OPCODE_TEXT): if not message: logger.warning("Can\'t send message, message is not valid UTF-8") return False - elif sys.version_info < (3,0) and (isinstance(message, str) or isinstance(message, unicode)): - pass - elif isinstance(message, str): - pass - else: - logger.warning('Can\'t send message, message has to be a string or bytes. Given type is %s' % type(message)) + elif not isinstance(message, str): + logger.warning('Can\'t send message, message has to be a string or bytes. Got %s' % type(message)) return False header = bytearray() @@ -293,7 +413,8 @@ def send_text(self, message, opcode=OPCODE_TEXT): raise Exception("Message is too big. Consider breaking it into chunks.") return - self.request.send(header + payload) + with self._send_lock: + self.request.send(header + payload) def read_http_headers(self): headers = {} @@ -326,7 +447,8 @@ def handshake(self): return response = self.make_handshake_response(key) - self.handshake_done = self.request.send(response.encode()) + with self._send_lock: + self.handshake_done = self.request.send(response.encode()) self.valid_client = True self.server._new_client_(self)