diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..68aeaf9 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,73 @@ +version: 2.1 + +orbs: + python: circleci/python@0.2.1 + +jobs: + test: + docker: + - image: circleci/python:3.7 + steps: + - checkout + - run: + name: install dependencies + command: | + pip install -r requirements.txt + - run: + name: run tests + command: | + pytest + - store_artifacts: + path: test-reports + destination: test-reports + deploy: + executor: python/default + steps: + - checkout + - python/load-cache + - run: + name: install twine + command: | + sudo pip install twine + - run: + name: verify git tag vs. version + command: | + python setup.py verify + - run: + name: create packages + command: | + python setup.py sdist + python setup.py bdist_wheel + - run: + name: setup pypi credentials + command: | + echo -e "[pypi]" >> ~/.pypirc + echo -e "username = $PYPI_USERNAME" >> ~/.pypirc + echo -e "password = $PYPI_PASSWORD" >> ~/.pypirc + - run: + name: upload to cheeseshop if a tag found + command: | + tag=`git tag --points-at HEAD` + if [ $tag ]; then + echo Uploading + python -m twine upload dist/* + fi + +workflows: + test: + jobs: + - test: + filters: + branches: + only: development + new_release: + jobs: + - test: + filters: + branches: + only: master + tags: + only: /v[0-9]+(\.[0-9]+)*/ + - deploy: + requires: + - test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..dc44a8d --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*.pyc +.cache +.pytest_cache +.tox +.env +*.egg-info diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..db7febd --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2018 Johan Hanssen Seferidis + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index eb12100..f0ac0da 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,26 @@ Websocket Server ======================= +[![CircleCI](https://circleci.com/gh/Pithikos/python-websocket-server/tree/master.svg?style=svg)](https://circleci.com/gh/Pithikos/python-websocket-server/tree/master) [![PyPI version](https://badge.fury.io/py/websocket-server.svg)](https://badge.fury.io/py/websocket-server) + A minimal Websockets Server in Python with no external dependencies. - * Works with Python2 and Python3 + * Python3.6+ * Clean simple API * Multiple clients * No dependencies - -Notice that this implementation does not support the more advanced features -like SSL etc. The project is focused mainly on making it easy to run a -websocket server for prototyping, testing or for making a GUI for your application. + +Notice this project is focused mainly on making it easy to run a websocket server for prototyping, testing or for making a GUI for your application. Thus not all possible features of Websockets are supported. + + +Installation +======================= + +Install with pip + + pip install websocket-server + +For coding details have a look at the [*server.py*](https://github.com/Pithikos/python-websocket-server/blob/master/server.py) example and the [API](https://github.com/Pithikos/python-websocket-server#api). Usage @@ -18,18 +28,16 @@ Usage You can get a feel of how to use the websocket server by running python server.py - + Then just open `client.html` in your browser and you should be able to send and receive messages. -Using in your project -======================= -You can either simply copy/paste the *websocket_server.py* file in your project and use it directly -or you can install the project directly from PyPi: +Testing +======= - pip install websocket-server +Run all tests -For coding details have a look at the [*server.py*](https://github.com/Pithikos/python-websocket-server/blob/master/server.py) example and the [API](https://github.com/Pithikos/python-websocket-server#api). + pytest API @@ -39,29 +47,45 @@ The API is simply methods and properties of the `WebsocketServer` class. ## WebsocketServer -The WebsocketServer takes two arguments: a `port` and a `hostname`. -By default `localhost` is used. However if you want to be able and connect -to the server from the network you need to pass `0.0.0.0` as hostname. +The WebsocketServer can be initialized with the below parameters. -###Properties +*`port`* - The port clients will need to connect to. + +*`host`* - By default the `127.0.0.1` is used which allows connections only from the current machine. If you wish to allow all network machines to connect, you need to pass `0.0.0.0` as hostname. + +*`loglevel`* - logging level to print. By default WARNING is used. You can use `logging.DEBUG` or `logging.INFO` for more verbose output. + +*`key`* - If using SSL, this is the path to the key. + +*`cert`* - If using SSL, this is the path to the certificate. + + +### Properties | Property | Description | |----------|----------------------| | clients | A list of `client` | -###Methods +### Methods | Method | Description | Takes | Gives | |-----------------------------|---------------------------------------------------------------------------------------|-----------------|-------| +| `run_forever()` | Runs server until shutdown_gracefully or shutdown_abruptly are called. | threaded: run server on its own thread if True | None | | `set_fn_new_client()` | Sets a callback function that will be called for every new `client` connecting to us | function | None | | `set_fn_client_left()` | Sets a callback function that will be called for every `client` disconnecting from us | function | None | | `set_fn_message_received()` | Sets a callback function that will be called when a `client` sends a message | function | None | | `send_message()` | Sends a `message` to a specific `client`. The message is a simple string. | client, message | None | | `send_message_to_all()` | Sends a `message` to **all** connected clients. The message is a simple string. | message | None | +| `disconnect_clients_gracefully()` | Disconnect all connected clients by sending a websocket CLOSE handshake. | Optional: status, reason | None | +| `disconnect_clients_abruptly()` | Disconnect all connected clients. Clients won't be aware until they try to send some data. | None | None | +| `shutdown_gracefully()` | Disconnect clients with a CLOSE handshake and shutdown server. | Optional: status, reason | None | +| `shutdown_abruptly()` | Disconnect clients and shutdown server with no handshake. | None | None | +| `deny_new_connections()` | Close connection for new clients. | Optional: status, reason | None | +| `allow_new_connections()` | Allows back connection for new clients. | | None | -###Callback functions +### Callback functions | Set by | Description | Parameters | |-----------------------------|---------------------------------------------------|-------------------------| @@ -70,31 +94,42 @@ to the server from the network you need to pass `0.0.0.0` as hostname. | `set_fn_message_received()` | Called when a `client` sends a `message` | client, server, message | -The client passed to the callback is the client that left, sent the message, etc. The server might not have any use to use. However it is -passed in case you want to send messages to clients. +The client passed to the callback is the client that left, sent the message, etc. The server might not have any use to use. However it is passed in case you want to send messages to clients. Example: -```` +````py +import logging +from websocket_server import WebsocketServer + +def new_client(client, server): + server.send_message_to_all("Hey all, a new client has joined us") + +server = WebsocketServer(host='127.0.0.1', port=13254, loglevel=logging.INFO) +server.set_fn_new_client(new_client) +server.run_forever() +```` +Example (SSL): +````py +import logging from websocket_server import WebsocketServer def new_client(client, server): server.send_message_to_all("Hey all, a new client has joined us") -server = WebsocketServer(13254) +server = WebsocketServer(host='127.0.0.1', port=13254, loglevel=logging.INFO, key="key.pem", cert="cert.pem") server.set_fn_new_client(new_client) server.run_forever() -```` +```` -##Client +## Client Client is just a dictionary passed along methods. -```` +```py { 'id' : client_id, 'handler' : client_handler, 'address' : (addr, port) } -```` - +``` diff --git a/docs/release-workflow.md b/docs/release-workflow.md new file mode 100644 index 0000000..316e5d8 --- /dev/null +++ b/docs/release-workflow.md @@ -0,0 +1,16 @@ +Release notes +------------- + +Releases are marked on master branch with tags. The upload to pypi is automated as long as a merge +from development comes with a tag. + +General flow + + 1. Get in dev branch + 2. Update VERSION in setup.py and releases.txt file + 3. Make a commit + 4. Merge development into master (`git merge --no-ff development`) + 4. Add corresponding version as a new tag (`git tag `) e.g. git tag v0.3.0 + 5. Push everything (`git push --tags && git push`) + +- diff --git a/releases.txt b/releases.txt new file mode 100644 index 0000000..2fda06c --- /dev/null +++ b/releases.txt @@ -0,0 +1,34 @@ +0.4 +- Python 2 and 3 support + +0.5.1 +- SSL support +- Drop Python 2 support + +0.5.4 +- Add API for shutting down server (abruptly & gracefully) + +0.5.5 +- Allow running run_forever threaded +- Fix shutting down of a server without connected clients + +0.5.6 +- Support from Python3.6+ + +0.6.0 +- Change order of params 'host' and 'port' +- Add host attribute to server + +0.6.1 +- Sending data is now thread-safe + +0.6.2 +- Add API for disconnecting clients + +0.6.3 +- Remove deprecation warnings + +0.6.4 +- Add deny_new_connections & allow_new_connections +- Fix disconnect_clients_gracefully to now take params +- Fix shutdown_gracefully unused param diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5075893 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +# Dev/test/deploy +IPython +pytest +websocket-client>=1.1.1 +twine diff --git a/server.py b/server.py index f0587c6..210c884 100644 --- a/server.py +++ b/server.py @@ -19,7 +19,7 @@ def message_received(client, server, message): PORT=9001 -server = WebsocketServer(PORT) +server = WebsocketServer(port = PORT) server.set_fn_new_client(new_client) server.set_fn_client_left(client_left) server.set_fn_message_received(message_received) diff --git a/setup.py b/setup.py index b495e1b..684b88d 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,46 @@ -from setuptools import setup, find_packages +import os +import sys +import re +import subprocess +import shlex + +try: + from setuptools import setup, find_packages + from setuptools.command.install import install +except ImportError: + from distutils.core import setup, find_packages + from distutils.command.install import install + + +VERSION = '0.6.4' + + +def get_tag_version(): + cmd = 'git tag --points-at HEAD' + versions = subprocess.check_output(shlex.split(cmd)).splitlines() + if not versions: + return None + if len(versions) != 1: + sys.exit(f"Trying to get tag via git: Expected excactly one tag, got {len(versions)}") + version = versions[0].decode() + if re.match('^v[0-9]', version): + version = version[1:] + return version + + +class VerifyVersionCommand(install): + """ Custom command to verify that the git tag matches our version """ + description = 'verify that the git tag matches our version' + + def run(self): + tag_version = get_tag_version() + if tag_version and tag_version != VERSION: + sys.exit(f"Git tag: {tag} does not match the version of this app: {VERSION}") + setup( name='websocket_server', - version='0.4', + version=VERSION, packages=find_packages("."), url='https://github.com/Pithikos/python-websocket-server', license='MIT', @@ -12,4 +50,8 @@ ], description='A simple fully working websocket-server in Python with no external dependencies', platforms='any', + cmdclass={ + 'verify': VerifyVersionCommand, + }, + python_requires=">=3.6", ) diff --git a/tests/_bootstrap_.py b/tests/_bootstrap_.py deleted file mode 100644 index 269a495..0000000 --- a/tests/_bootstrap_.py +++ /dev/null @@ -1,4 +0,0 @@ -#Bootstrap -import sys, os -if 'python-websockets-server' in os.getcwd(): - sys.path.insert(0, '..') diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d8ddf26 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,92 @@ +import logging +from time import sleep +from threading import Thread + +import pytest +import websocket # websocket-client + +# Add path to source code +import sys, os +if os.getcwd().endswith('tests'): + sys.path.insert(0, '..') +elif os.path.exists('websocket_server'): + sys.path.insert(0, '.') +from websocket_server import WebsocketServer + + +class TestClient(): + def __init__(self, port, threaded=True): + self.received_messages = [] + self.closes = [] + self.opens = [] + self.errors = [] + + websocket.enableTrace(True) + self.ws = websocket.WebSocketApp(f"ws://localhost:{port}/", + on_open=self.on_open, + on_message=self.on_message, + on_error=self.on_error, + on_close=self.on_close) + if threaded: + self.thread = Thread(target=self.ws.run_forever) + self.thread.daemon = True + self.thread.start() + else: + self.ws.run_forever() + + def on_message(self, ws, message): + self.received_messages.append(message) + print(f"TestClient: on_message: {message}") + + def on_error(self, ws, error): + self.errors.append(error) + print(f"TestClient: on_error: {error}") + + def on_close(self, ws, close_status_code, close_msg): + self.closes.append((close_status_code, close_msg)) + print(f"TestClient: on_close: {close_status_code} - {close_msg}") + + def on_open(self, ws): + self.opens.append(ws) + print("TestClient: on_open") + + +class TestServer(WebsocketServer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.received_messages = [] + self.set_fn_message_received(self.handle_received_message) + + def handle_received_message(self, client, server, message): + self.received_messages.append(message) + + +@pytest.fixture(scope='function') +def threaded_server(): + """ Returns the response of a server after""" + server = TestServer(loglevel=logging.DEBUG) + server.run_forever(threaded=True) + yield server + server.server_close() + + +@pytest.fixture +def session(threaded_server): + """ + Gives a simple connection to a server + """ + conn = websocket.create_connection("ws://{}:{}".format(*threaded_server.server_address)) + yield conn, threaded_server + conn.close() + + +@pytest.fixture +def client_session(threaded_server): + """ + Gives a TestClient instance connected to a server + """ + client = TestClient(port=threaded_server.port) + sleep(1) + assert client.ws.sock and client.ws.sock.connected + yield client, threaded_server + client.ws.close() diff --git a/tests/handshake.py b/tests/handshake.py deleted file mode 100644 index 1ebfe1d..0000000 --- a/tests/handshake.py +++ /dev/null @@ -1,40 +0,0 @@ -import _bootstrap_ -from websocket import * - - -handler = DummyWebsocketHandler() - - - - -pairs = [ - # Key # Response - ('zyjFH2rQwrTtNFk5lwEMQg==', '2hnZADGmT/V1/w1GJYBtttUKASY='), - ('XJuxlsdq0QrVyKwA/D9D5A==', 'tZ5RV3pw7nP9cF+HDvTd89WJKj8=') -] - - -# Test hash calculations for response -key = 'zyjFH2rQwrTtNFk5lwEMQg==' -resp = handler.calculate_response_key(key) -assert resp == '2hnZADGmT/V1/w1GJYBtttUKASY=' - - -# Test response messages -key = 'zyjFH2rQwrTtNFk5lwEMQg==' -expect = \ - 'HTTP/1.1 101 Switching Protocols\r\n'\ - 'Upgrade: websocket\r\n' \ - 'Connection: Upgrade\r\n' \ - 'Sec-WebSocket-Accept: 2hnZADGmT/V1/w1GJYBtttUKASY=\r\n'\ - '\r\n' -resp = handler.make_handshake_response(key) -assert resp == expect - - - - - - - -print("No errors") diff --git a/tests/message_lengths.py b/tests/message_lengths.py deleted file mode 100644 index 079d1c1..0000000 --- a/tests/message_lengths.py +++ /dev/null @@ -1,57 +0,0 @@ -import _bootstrap_ -from websocket import WebSocketsServer -from time import sleep -from testsuite.messages import * - -''' -This creates just a server that will send a different message to every new connection: - - 1. A message of length less than 126 - 2. A message of length 126 - 3. A message of length 127 - 4. A message of length bigger than 127 - 5. A message above 1024 - 6. A message above 65K - 7. An enormous message (well beyond 65K) - - -Reconnect to get the next message -''' - - -counter = 0 - -# Called for every client connecting (after handshake) -def new_client(client, server): - print("New client connected and was given id %d" % client['id']) - global counter - if counter == 0: - print("Sending message 1 of length %d" % len(msg_125B)) - server.send_message(client, msg_125B) - elif counter == 1: - print("Sending message 2 of length %d" % len(msg_126B)) - server.send_message(client, msg_126B) - elif counter == 2: - print("Sending message 3 of length %d" % len(msg_127B)) - server.send_message(client, msg_127B) - elif counter == 3: - print("Sending message 4 of length %d" % len(msg_208B)) - server.send_message(client, msg_208B) - elif counter == 4: - print("Sending message 5 of length %d" % len(msg_1251B)) - server.send_message(client, msg_1251B) - elif counter == 5: - print("Sending message 6 of length %d" % len(msg_68KB)) - server.send_message(client, msg_68KB) - elif counter == 6: - print("Sending message 7 of length %d" % len(msg_1500KB)) - server.send_message(client, msg_1500KB) - else: - print("No errors") - counter += 1 - - -PORT=9001 -server = WebSocketsServer(PORT) -server.set_fn_new_client(new_client) -server.run_forever() diff --git a/tests/test_handshake.py b/tests/test_handshake.py new file mode 100644 index 0000000..74ace26 --- /dev/null +++ b/tests/test_handshake.py @@ -0,0 +1,17 @@ +from websocket_server import WebSocketHandler + + +def test_hash_calculations_for_response(): + assert WebSocketHandler.calculate_response_key('zyjFH2rQwrTtNFk5lwEMQg==') == '2hnZADGmT/V1/w1GJYBtttUKASY=' + + +def test_response_messages(): + key = 'zyjFH2rQwrTtNFk5lwEMQg==' + expected = \ + 'HTTP/1.1 101 Switching Protocols\r\n'\ + 'Upgrade: websocket\r\n' \ + 'Connection: Upgrade\r\n' \ + 'Sec-WebSocket-Accept: 2hnZADGmT/V1/w1GJYBtttUKASY=\r\n'\ + '\r\n' + handshake_content = WebSocketHandler.make_handshake_response(key) + assert handshake_content == expected diff --git a/tests/test_message_lengths.py b/tests/test_message_lengths.py new file mode 100644 index 0000000..03f9d96 --- /dev/null +++ b/tests/test_message_lengths.py @@ -0,0 +1,67 @@ +def test_text_message_of_length_1(session): + conn, server = session + server.send_message_to_all('$') + assert conn.recv() == '$' + + +def test_text_message_of_length_125B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqr125' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_126B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrs126' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_127B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrst127' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_208B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvw208' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_1251B(session): + conn, server = session + msg = ('abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqr125'*10)+'1' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_68KB(session): + conn, server = session + msg = '$'+('a'*67993)+'68000'+'^' + assert len(msg) == 68000 + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_1500KB(session): + """ An enormous message (well beyond 65K) """ + conn, server = session + msg = '$'+('a'*1499991)+'1500000'+'^' + assert len(msg) == 1500000 + server.send_message_to_all(msg) + assert conn.recv() == msg diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..d03ce08 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,159 @@ +from time import sleep +import threading + +import websocket +import pytest + + +class TestServerThreadedWithoutClient(): + def test_run_forever(self, threaded_server): + assert threaded_server.thread + assert not isinstance(threaded_server.thread, threading._MainThread) + assert threaded_server.thread.is_alive() + + def test_attributes(self, threaded_server): + tpl = threaded_server.server_address + assert threaded_server.port == tpl[1] + assert threaded_server.host == tpl[0] + + def test_shutdown(self, threaded_server): + assert threaded_server.thread.is_alive() + + # Shutdown de-facto way + # REF: https://docs.python.org/3/library/socketserver.html + # "Tell the serve_forever() loop to stop and + # wait until it does. shutdown() must be called while serve_forever() + # is running in a different thread otherwise it will deadlock." + threaded_server.shutdown() + assert not threaded_server.thread.is_alive() + + def test_shutdown_gracefully_without_clients(self, threaded_server): + assert threaded_server.thread.is_alive() + threaded_server.shutdown_gracefully() + assert not threaded_server.thread.is_alive() + assert threaded_server.socket.fileno() <= 0 + + def test_shutdown_abruptly_without_clients(self, threaded_server): + assert threaded_server.thread.is_alive() + threaded_server.shutdown_abruptly() + assert not threaded_server.thread.is_alive() + assert threaded_server.socket.fileno() <= 0 + + +class TestServerThreadedWithClient(): + def test_send_close(self, client_session): + """ + Ensure client stops receiving data once we send_close (socket is still open) + """ + client, server = client_session + assert client.received_messages == [] + + server.send_message_to_all("test1") + sleep(0.5) + assert client.received_messages == ["test1"] + + # After CLOSE, client should not be receiving any messages + server.clients[-1]["handler"].send_close() + sleep(0.5) + server.send_message_to_all("test2") + sleep(0.5) + assert client.received_messages == ["test1"] + + def test_shutdown_gracefully(self, client_session): + client, server = client_session + assert client.ws.sock and client.ws.sock.connected + assert server.socket.fileno() > 0 + + server.shutdown_gracefully() + sleep(0.5) + + # Ensure all parties disconnected + assert not client.ws.sock + assert server.socket.fileno() == -1 + assert not server.clients + + def test_shutdown_abruptly(self, client_session): + client, server = client_session + assert client.ws.sock and client.ws.sock.connected + assert server.socket.fileno() > 0 + + server.shutdown_abruptly() + sleep(0.5) + + # Ensure server socket died + assert server.socket.fileno() == -1 + + # Ensure client handler terminated + assert server.received_messages == [] + assert client.errors == [] + client.ws.send("1st msg after server shutdown") + sleep(0.5) + + # Note the message is received since the client handler + # will terminate only once it has received the last message + # and break out of the keep_alive loop. Any consecutive messages + # will not be received though. + assert server.received_messages == ["1st msg after server shutdown"] + assert len(client.errors) == 1 + assert isinstance(client.errors[0], websocket._exceptions.WebSocketConnectionClosedException) + + # Try to send 2nd message + with pytest.raises(websocket._exceptions.WebSocketConnectionClosedException): + client.ws.send("2nd msg after server shutdown") + + def test_client_closes_gracefully(self, session): + client, server = session + assert client.connected + assert server.clients + old_client_handler = server.clients[0]["handler"] + client.close() + assert not client.connected + + # Ensure server closed connection. + # We test this by having the server trying to send + # data to the client + assert not server.clients + with pytest.raises(BrokenPipeError): + old_client_handler.connection.send(b"test") + + def test_disconnect_clients_abruptly(self, session): + client, server = session + assert client.connected + assert server.clients + server.disconnect_clients_abruptly() + assert not server.clients + + # Client won't be aware until trying to write more data + with pytest.raises(BrokenPipeError): + for i in range(3): + client.send("test") + sleep(0.2) + + def test_disconnect_clients_gracefully(self, session): + client, server = session + assert client.connected + assert server.clients + server.disconnect_clients_gracefully() + assert not server.clients + + # Client won't be aware until trying to write more data + with pytest.raises(BrokenPipeError): + for i in range(3): + client.send("test") + sleep(0.2) + + def test_deny_new_connections(self, threaded_server): + url = "ws://{}:{}".format(*threaded_server.server_address) + server = threaded_server + server.deny_new_connections(status=1013, reason=b"Please try re-connecting later") + + conn = websocket.create_connection(url) + try: + conn.send("test") + except websocket.WebSocketProtocolException as e: + assert 'Invalid close opcode' in e.args[0] + assert not server.clients + + server.allow_new_connections() + conn = websocket.create_connection(url) + conn.send("test") diff --git a/tests/test_text_messages.py b/tests/test_text_messages.py new file mode 100644 index 0000000..b4f2d0f --- /dev/null +++ b/tests/test_text_messages.py @@ -0,0 +1,124 @@ +def test_text_message_of_length_1(session): + conn, server = session + server.send_message_to_all('$') + assert conn.recv() == '$' + + +def test_text_message_of_length_125B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqr125' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_126B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrs126' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_127B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrst127' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_208B(session): + conn, server = session + msg = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvw208' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_1251B(session): + conn, server = session + msg = ('abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ + 'abcdefghijklmnopqr125'*10)+'1' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_68KB(session): + conn, server = session + msg = '$'+('a'*67993)+'68000'+'^' + assert len(msg) == 68000 + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_of_length_1500KB(session): + """ An enormous message (well beyond 65K) """ + conn, server = session + msg = '$'+('a'*1499991)+'1500000'+'^' + assert len(msg) == 1500000 + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_with_unicode_characters(session): + conn, server = session + msg = '$äüö^' + server.send_message_to_all(msg) + assert conn.recv() == msg + + +def test_text_message_stress_bursts(session): + """ Scenario: server sends multiple different message to the same conn + at once """ + from threading import Thread + NUM_THREADS = 100 + MESSAGE_LEN = 1000 + conn, server = session + messages_received = [] + + # Threads receing + threads_receiving = [] + for i in range(NUM_THREADS): + th = Thread( + target=lambda fn: messages_received.append(fn()), + args=(conn.recv,) + ) + th.daemon = True + threads_receiving.append(th) + + # Threads sending different characters each of them + threads_sending = [] + for i in range(NUM_THREADS): + message = chr(i)*MESSAGE_LEN + th = Thread( + target=server.send_message_to_all, + args=(message,) + ) + th.daemon = True + threads_sending.append(th) + + # Run scenario + for th in threads_receiving: + th.start() + for th in threads_sending: + th.start() + + # Wait for all threads to finish + print('WAITING FOR THREADS TO FINISH') + for th in threads_receiving: + th.join() + for th in threads_sending: + th.join() + + for message in messages_received: + first_char = message[0] + assert message.count(first_char) == len(message) + + print() diff --git a/tests/testsuite/__init__.py b/tests/testsuite/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/testsuite/messages.py b/tests/testsuite/messages.py deleted file mode 100644 index 5d0a93d..0000000 --- a/tests/testsuite/messages.py +++ /dev/null @@ -1,21 +0,0 @@ -# -# Fixed messages by length -# Every message ends with its length.. -# - -msg_125B = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqr125' -msg_126B = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqrs126' -msg_127B = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqrst127' -msg_208B = 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'\ - 'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvw208' -msg_1251B = (msg_125B*10)+'1' # 1251 -msg_68KB = ('a'*67995)+'68000' # 68000 -msg_1500KB = ('a'*1500000)+'1500000' # 1.5Mb diff --git a/websocket_server/thread.py b/websocket_server/thread.py new file mode 100644 index 0000000..a474203 --- /dev/null +++ b/websocket_server/thread.py @@ -0,0 +1,38 @@ +import threading + + +class ThreadWithLoggedException(threading.Thread): + """ + Similar to Thread but will log exceptions to passed logger. + + Args: + logger: Logger instance used to log any exception in child thread + + Exception is also reachable via .exception from the main thread. + """ + + DIVIDER = "*"*80 + + def __init__(self, *args, **kwargs): + try: + self.logger = kwargs.pop("logger") + except KeyError: + raise Exception("Missing 'logger' in kwargs") + super().__init__(*args, **kwargs) + self.exception = None + + def run(self): + try: + if self._target is not None: + self._target(*self._args, **self._kwargs) + except Exception as exception: + thread = threading.current_thread() + self.exception = exception + self.logger.exception(f"{self.DIVIDER}\nException in child thread {thread}: {exception}\n{self.DIVIDER}") + finally: + del self._target, self._args, self._kwargs + + +class WebsocketServerThread(ThreadWithLoggedException): + """Dummy wrapper to make debug messages a bit more readable""" + pass diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 550099b..c954c34 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -1,18 +1,21 @@ # Author: Johan Hanssen Seferidis # License: MIT -import re, sys +import sys import struct +import ssl from base64 import b64encode from hashlib import sha1 +import logging +from socket import error as SocketError +import errno +import threading +from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler -if sys.version_info[0] < 3 : - from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler -else: - from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler - - +from websocket_server.thread import WebsocketServerThread +logger = logging.getLogger(__name__) +logging.basicConfig() ''' +-+-+-+-+-------+-+-------------+-------------------------------+ @@ -37,268 +40,453 @@ PAYLOAD_LEN_EXT16 = 0x7e PAYLOAD_LEN_EXT64 = 0x7f -OPCODE_TEXT = 0x01 -CLOSE_CONN = 0x8 +OPCODE_CONTINUATION = 0x0 +OPCODE_TEXT = 0x1 +OPCODE_BINARY = 0x2 +OPCODE_CLOSE_CONN = 0x8 +OPCODE_PING = 0x9 +OPCODE_PONG = 0xA +CLOSE_STATUS_NORMAL = 1000 +DEFAULT_CLOSE_REASON = bytes('', encoding='utf-8') -# -------------------------------- API --------------------------------- - class API(): - def run_forever(self): - try: - print("Listening on port %d for clients.." % self.port) - self.serve_forever() - except KeyboardInterrupt: - self.server_close() - print("Server terminated.") - except Exception as e: - print("ERROR: WebSocketsServer: "+str(e)) - exit(1) - def new_client(self, client, server): - pass - def client_left(self, client, server): - pass - def message_received(self, client, server, message): - pass - def set_fn_new_client(self, fn): - self.new_client=fn - def set_fn_client_left(self, fn): - self.client_left=fn - def set_fn_message_received(self, fn): - self.message_received=fn - def send_message(self, client, msg): - self._unicast_(client, msg) - def send_message_to_all(self, msg): - self._multicast_(msg) - - - -# ------------------------- Implementation ----------------------------- -class WebsocketServer(ThreadingMixIn, TCPServer, API): + def run_forever(self, threaded=False): + return self._run_forever(threaded) + + def new_client(self, client, server): + pass + + def client_left(self, client, server): + pass - allow_reuse_address = True - daemon_threads = True # comment to keep threads alive until finished - - ''' - clients is a list of dict: - { - 'id' : id, - 'handler' : handler, - 'address' : (addr, port) - } - ''' - clients=[] - id_counter=0 - - def __init__(self, port, host='127.0.0.1'): - self.port=port - TCPServer.__init__(self, (host, port), WebSocketHandler) - - def _message_received_(self, handler, msg): - self.message_received(self.handler_to_client(handler), self, msg) - - def _new_client_(self, handler): - self.id_counter += 1 - client={ - 'id' : self.id_counter, - 'handler' : handler, - 'address' : handler.client_address - } - self.clients.append(client) - self.new_client(client, self) - - def _client_left_(self, handler): - client=self.handler_to_client(handler) - self.client_left(client, self) - if client in self.clients: - self.clients.remove(client) - - def _unicast_(self, to_client, msg): - to_client['handler'].send_message(msg) - - def _multicast_(self, msg): - for client in self.clients: - self._unicast_(client, msg) - - def handler_to_client(self, handler): - for client in self.clients: - if client['handler'] == handler: - return client + def message_received(self, client, server, message): + pass + def set_fn_new_client(self, fn): + self.new_client = fn + def set_fn_client_left(self, fn): + self.client_left = fn -class WebSocketHandler(StreamRequestHandler): + def set_fn_message_received(self, fn): + self.message_received = fn - def __init__(self, socket, addr, server): - self.server=server - StreamRequestHandler.__init__(self, socket, addr, server) - - def setup(self): - StreamRequestHandler.setup(self) - self.keep_alive = True - self.handshake_done = False - self.valid_client = False - - def handle(self): - while self.keep_alive: - if not self.handshake_done: - self.handshake() - elif self.valid_client: - self.read_next_message() - - def read_bytes(self, num): - # python3 gives ordinal of byte directly - bytes = self.rfile.read(num) - if sys.version_info[0] < 3: - return map(ord, bytes) - else: - return bytes - - def read_next_message(self): - - b1, b2 = self.read_bytes(2) - - fin = b1 & FIN - opcode = b1 & OPCODE - masked = b2 & MASKED - payload_length = b2 & PAYLOAD_LEN - - if not b1: - print("Client closed connection.") - self.keep_alive = 0 - return - if opcode == CLOSE_CONN: - print("Client asked to close connection.") - self.keep_alive = 0 - return - if not masked: - print("Client must always be masked.") - self.keep_alive = 0 - return - - if payload_length == 126: - payload_length = struct.unpack(">H", self.rfile.read(2))[0] - elif payload_length == 127: - payload_length = struct.unpack(">Q", self.rfile.read(8))[0] - - masks = self.read_bytes(4) - decoded = "" - for char in self.read_bytes(payload_length): - char ^= masks[len(decoded) % 4] - decoded += chr(char) - self.server._message_received_(self, decoded) - - def send_message(self, message): - self.send_text(message) - - def send_text(self, message): - ''' - NOTES - Fragmented(=continuation) messages are not being used since their usage - is needed in very limited cases - when we don't know the payload length. - ''' - - # Validate message - if isinstance(message, bytes): - message = try_decode_UTF8(message) # this is slower but assures we have UTF-8 - if not message: - print("Can\'t send message, message is not valid UTF-8") - return False - elif isinstance(message, str) or isinstance(message, unicode): - pass - else: - print('Can\'t send message, message has to be a string or bytes. Given type is %s' % type(message)) - return False - - header = bytearray() - payload = encode_to_UTF8(message) - payload_length = len(payload) - - # Normal payload - if payload_length <= 125: - header.append(FIN | OPCODE_TEXT) - header.append(payload_length) - - # Extended payload - elif payload_length >= 126 and payload_length <= 65535: - header.append(FIN | OPCODE_TEXT) - header.append(PAYLOAD_LEN_EXT16) - header.extend(struct.pack(">H", payload_length)) - - # Huge extended payload - elif payload_length < 18446744073709551616: - header.append(FIN | OPCODE_TEXT) - header.append(PAYLOAD_LEN_EXT64) - header.extend(struct.pack(">Q", payload_length)) - - else: - raise Exception("Message is too big. Consider breaking it into chunks.") - return - - self.request.send(header + payload) - - def handshake(self): - message = self.request.recv(1024).decode().strip() - upgrade = re.search('\nupgrade[\s]*:[\s]*websocket', message.lower()) - if not upgrade: - self.keep_alive = False - return - key = re.search('\n[sS]ec-[wW]eb[sS]ocket-[kK]ey[\s]*:[\s]*(.*)\r\n', message) - if key: - key = key.group(1) - else: - print("Client tried to connect but was missing a key") - self.keep_alive = False - return - response = self.make_handshake_response(key) - self.handshake_done = self.request.send(response.encode()) - self.valid_client = True - self.server._new_client_(self) - - def make_handshake_response(self, key): - return \ - 'HTTP/1.1 101 Switching Protocols\r\n'\ - 'Upgrade: websocket\r\n' \ - 'Connection: Upgrade\r\n' \ - 'Sec-WebSocket-Accept: %s\r\n' \ - '\r\n' % self.calculate_response_key(key) - - def calculate_response_key(self, key): - GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' - hash = sha1(key.encode() + GUID.encode()) - response_key = b64encode(hash.digest()).strip() - return response_key.decode('ASCII') - - def finish(self): - self.server._client_left_(self) + def send_message(self, client, msg): + self._unicast(client, msg) + def send_message_to_all(self, msg): + self._multicast(msg) + def deny_new_connections(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + self._deny_new_connections(status, reason) -def encode_to_UTF8(data): - try: - return data.encode('UTF-8') - except UnicodeEncodeError as e: - print("Could not encode data to UTF-8 -- %s" % e) - return False - except Exception as e: - raise(e) - return False + def allow_new_connections(self): + self._allow_new_connections() + def shutdown_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + self._shutdown_gracefully(status, reason) + def shutdown_abruptly(self): + self._shutdown_abruptly() -def try_decode_UTF8(data): - try: - return data.decode('utf-8') - except UnicodeDecodeError: - return False - except Exception as e: - raise(e) - - - -# This is only for testing purposes -class DummyWebsocketHandler(WebSocketHandler): - def __init__(self, *_): + def disconnect_clients_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + self._disconnect_clients_gracefully(status, reason) + + def disconnect_clients_abruptly(self): + self._disconnect_clients_abruptly() + + +class WebsocketServer(ThreadingMixIn, TCPServer, API): + """ + A websocket server waiting for clients to connect. + + Args: + port(int): Port to bind to + host(str): Hostname or IP to listen for connections. By default 127.0.0.1 + is being used. To accept connections from any client, you should use + 0.0.0.0. + loglevel: Logging level from logging module to use for logging. By default + warnings and errors are being logged. + + Properties: + clients(list): A list of connected clients. A client is a dictionary + like below. + { + 'id' : id, + 'handler' : handler, + 'address' : (addr, port) + } + """ + + allow_reuse_address = True + daemon_threads = True # comment to keep threads alive until finished + + def __init__(self, host='127.0.0.1', port=0, loglevel=logging.WARNING, key=None, cert=None): + logger.setLevel(loglevel) + TCPServer.__init__(self, (host, port), WebSocketHandler) + self.host = host + self.port = self.socket.getsockname()[1] + + self.key = key + self.cert = cert + + self.clients = [] + self.id_counter = 0 + self.thread = None + + self._deny_clients = False + + def _run_forever(self, threaded): + cls_name = self.__class__.__name__ + try: + logger.info("Listening on port %d for clients.." % self.port) + if threaded: + self.daemon = True + self.thread = WebsocketServerThread(target=super().serve_forever, daemon=True, logger=logger) + logger.info(f"Starting {cls_name} on thread {self.thread.getName()}.") + self.thread.start() + else: + self.thread = threading.current_thread() + logger.info(f"Starting {cls_name} on main thread.") + super().serve_forever() + except KeyboardInterrupt: + self.server_close() + logger.info("Server terminated.") + except Exception as e: + logger.error(str(e), exc_info=True) + sys.exit(1) + + def _message_received_(self, handler, msg): + self.message_received(self.handler_to_client(handler), self, msg) + + def _ping_received_(self, handler, msg): + handler.send_pong(msg) + + def _pong_received_(self, handler, msg): pass + + def _new_client_(self, handler): + if self._deny_clients: + status = self._deny_clients["status"] + reason = self._deny_clients["reason"] + handler.send_close(status, reason) + self._terminate_client_handler(handler) + return + + self.id_counter += 1 + client = { + 'id': self.id_counter, + 'handler': handler, + 'address': handler.client_address + } + self.clients.append(client) + self.new_client(client, self) + + def _client_left_(self, handler): + client = self.handler_to_client(handler) + self.client_left(client, self) + if client in self.clients: + self.clients.remove(client) + + def _unicast(self, receiver_client, msg): + receiver_client['handler'].send_message(msg) + + def _multicast(self, msg): + for client in self.clients: + self._unicast(client, msg) + + def handler_to_client(self, handler): + for client in self.clients: + if client['handler'] == handler: + return client + + def _terminate_client_handler(self, handler): + handler.keep_alive = False + handler.finish() + handler.connection.close() + + def _terminate_client_handlers(self): + """ + Ensures request handler for each client is terminated correctly + """ + for client in self.clients: + self._terminate_client_handler(client["handler"]) + + def _shutdown_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + """ + Send a CLOSE handshake to all connected clients before terminating server + """ + self.keep_alive = False + self._disconnect_clients_gracefully(status, reason) + self.server_close() + self.shutdown() + + def _shutdown_abruptly(self): + """ + Terminate server without sending a CLOSE handshake + """ + self.keep_alive = False + self._disconnect_clients_abruptly() + self.server_close() + self.shutdown() + + def _disconnect_clients_gracefully(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + """ + Terminate clients gracefully without shutting down the server + """ + for client in self.clients: + client["handler"].send_close(status, reason) + self._terminate_client_handlers() + + def _disconnect_clients_abruptly(self): + """ + Terminate clients abruptly (no CLOSE handshake) without shutting down the server + """ + self._terminate_client_handlers() + + def _deny_new_connections(self, status, reason): + self._deny_clients = { + "status": status, + "reason": reason, + } + + def _allow_new_connections(self): + self._deny_clients = False + + +class WebSocketHandler(StreamRequestHandler): + + def __init__(self, socket, addr, server): + self.server = server + assert not hasattr(self, "_send_lock"), "_send_lock already exists" + self._send_lock = threading.Lock() + if server.key and server.cert: + try: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(certfile=server.cert, keyfile=server.key) + socket = ssl_context.wrap_socket(socket, server_side=True) + except FileNotFoundError: + logger.warning("SSL key or certificate file not found. Please check the paths for the key and cert.") + except ssl.SSLError as e: + logger.warning(f"SSL error occurred: {e}") + StreamRequestHandler.__init__(self, socket, addr, server) + + def setup(self): + StreamRequestHandler.setup(self) + self.keep_alive = True + self.handshake_done = False + self.valid_client = False + + def handle(self): + while self.keep_alive: + if not self.handshake_done: + self.handshake() + elif self.valid_client: + self.read_next_message() + + def read_bytes(self, num): + return self.rfile.read(num) + + def read_next_message(self): + try: + b1, b2 = self.read_bytes(2) + except SocketError as e: # to be replaced with ConnectionResetError for py3 + if e.errno == errno.ECONNRESET: + logger.info("Client closed connection.") + self.keep_alive = 0 + return + b1, b2 = 0, 0 + except ValueError as e: + b1, b2 = 0, 0 + + fin = b1 & FIN + opcode = b1 & OPCODE + masked = b2 & MASKED + payload_length = b2 & PAYLOAD_LEN + + if opcode == OPCODE_CLOSE_CONN: + logger.info("Client asked to close connection.") + self.keep_alive = 0 + return + if not masked: + logger.warning("Client must always be masked.") + self.keep_alive = 0 + return + if opcode == OPCODE_CONTINUATION: + logger.warning("Continuation frames are not supported.") + return + elif opcode == OPCODE_BINARY: + logger.warning("Binary frames are not supported.") + return + elif opcode == OPCODE_TEXT: + opcode_handler = self.server._message_received_ + elif opcode == OPCODE_PING: + opcode_handler = self.server._ping_received_ + elif opcode == OPCODE_PONG: + opcode_handler = self.server._pong_received_ + else: + logger.warning("Unknown opcode %#x." % opcode) + self.keep_alive = 0 + return + + if payload_length == 126: + payload_length = struct.unpack(">H", self.rfile.read(2))[0] + elif payload_length == 127: + payload_length = struct.unpack(">Q", self.rfile.read(8))[0] + + masks = self.read_bytes(4) + message_bytes = bytearray() + for message_byte in self.read_bytes(payload_length): + message_byte ^= masks[len(message_bytes) % 4] + message_bytes.append(message_byte) + opcode_handler(self, message_bytes.decode('utf8')) + + def send_message(self, message): + self.send_text(message) + + def send_pong(self, message): + self.send_text(message, OPCODE_PONG) + + def send_close(self, status=CLOSE_STATUS_NORMAL, reason=DEFAULT_CLOSE_REASON): + """ + Send CLOSE to client + + Args: + status: Status as defined in https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 + reason: Text with reason of closing the connection + """ + if status < CLOSE_STATUS_NORMAL or status > 1015: + raise Exception(f"CLOSE status must be between 1000 and 1015, got {status}") + + header = bytearray() + payload = struct.pack('!H', status) + reason + payload_length = len(payload) + assert payload_length <= 125, "We only support short closing reasons at the moment" + + # Send CLOSE with status & reason + header.append(FIN | OPCODE_CLOSE_CONN) + header.append(payload_length) + with self._send_lock: + self.request.send(header + payload) + + def send_text(self, message, opcode=OPCODE_TEXT): + """ + Important: Fragmented(=continuation) messages are not supported since + their usage cases are limited - when we don't know the payload length. + """ + + # Validate message + if isinstance(message, bytes): + message = try_decode_UTF8(message) # this is slower but ensures we have UTF-8 + if not message: + logger.warning("Can\'t send message, message is not valid UTF-8") + return False + elif not isinstance(message, str): + logger.warning('Can\'t send message, message has to be a string or bytes. Got %s' % type(message)) + return False + + header = bytearray() + payload = encode_to_UTF8(message) + payload_length = len(payload) + + # Normal payload + if payload_length <= 125: + header.append(FIN | opcode) + header.append(payload_length) + + # Extended payload + elif payload_length >= 126 and payload_length <= 65535: + header.append(FIN | opcode) + header.append(PAYLOAD_LEN_EXT16) + header.extend(struct.pack(">H", payload_length)) + + # Huge extended payload + elif payload_length < 18446744073709551616: + header.append(FIN | opcode) + header.append(PAYLOAD_LEN_EXT64) + header.extend(struct.pack(">Q", payload_length)) + + else: + raise Exception("Message is too big. Consider breaking it into chunks.") + return + + with self._send_lock: + self.request.send(header + payload) + + def read_http_headers(self): + headers = {} + # first line should be HTTP GET + http_get = self.rfile.readline().decode().strip() + assert http_get.upper().startswith('GET') + # remaining should be headers + while True: + header = self.rfile.readline().decode().strip() + if not header: + break + head, value = header.split(':', 1) + headers[head.lower().strip()] = value.strip() + return headers + + def handshake(self): + headers = self.read_http_headers() + + try: + assert headers['upgrade'].lower() == 'websocket' + except AssertionError: + self.keep_alive = False + return + + try: + key = headers['sec-websocket-key'] + except KeyError: + logger.warning("Client tried to connect but was missing a key") + self.keep_alive = False + return + + response = self.make_handshake_response(key) + with self._send_lock: + self.handshake_done = self.request.send(response.encode()) + self.valid_client = True + self.server._new_client_(self) + + @classmethod + def make_handshake_response(cls, key): + return \ + 'HTTP/1.1 101 Switching Protocols\r\n'\ + 'Upgrade: websocket\r\n' \ + 'Connection: Upgrade\r\n' \ + 'Sec-WebSocket-Accept: %s\r\n' \ + '\r\n' % cls.calculate_response_key(key) + + @classmethod + def calculate_response_key(cls, key): + GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + hash = sha1(key.encode() + GUID.encode()) + response_key = b64encode(hash.digest()).strip() + return response_key.decode('ASCII') + + def finish(self): + self.server._client_left_(self) + + +def encode_to_UTF8(data): + try: + return data.encode('UTF-8') + except UnicodeEncodeError as e: + logger.error("Could not encode data to UTF-8 -- %s" % e) + return False + except Exception as e: + raise(e) + return False + + +def try_decode_UTF8(data): + try: + return data.decode('utf-8') + except UnicodeDecodeError: + return False + except Exception as e: + raise(e)