From 29f8d950e49b555c39d28ca1c5bf383442fcc6e2 Mon Sep 17 00:00:00 2001 From: Zach Wegner Date: Wed, 17 Jul 2019 13:35:51 -0500 Subject: [PATCH 1/7] Change websocket_server a bunch, remove unneeded APIs, add routing mechanism, add server thread --- websocket_server/websocket_server.py | 214 ++++++++------------------- 1 file changed, 61 insertions(+), 153 deletions(-) diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 96c658b..a423685 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -1,13 +1,14 @@ # Author: Johan Hanssen Seferidis # License: MIT +import errno +import logging import sys import struct +import threading from base64 import b64encode from hashlib import sha1 -import logging from socket import error as SocketError -import errno if sys.version_info[0] < 3: from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler @@ -15,7 +16,33 @@ from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler logger = logging.getLogger(__name__) -logging.basicConfig() + +def encode_to_UTF8(data): + try: + return data.encode('UTF-8') + except UnicodeEncodeError as e: + logger.error("Could not encode data to UTF-8 -- %s" % e) + return False + except Exception as e: + raise(e) + return False + +def try_decode_UTF8(data): + try: + return data.decode('utf-8') + except UnicodeDecodeError: + return False + except Exception as e: + raise(e) + +ROUTING_TABLE = {} + +# Decorator for routing websocket handlers a la Flask +def route(path): + def decorate(fn): + ROUTING_TABLE[path] = fn + return fn + return decorate ''' +-+-+-+-+-------+-+-------------+-------------------------------+ @@ -47,122 +74,7 @@ OPCODE_PING = 0x9 OPCODE_PONG = 0xA - -# -------------------------------- API --------------------------------- - -class API(): - - def run_forever(self): - try: - logger.info("Listening on port %d for clients.." % self.port) - self.serve_forever() - except KeyboardInterrupt: - self.server_close() - logger.info("Server terminated.") - except Exception as e: - logger.error(str(e), exc_info=True) - exit(1) - - def new_client(self, client, server): - pass - - def client_left(self, client, server): - pass - - def message_received(self, client, server, message): - pass - - def set_fn_new_client(self, fn): - self.new_client = fn - - def set_fn_client_left(self, fn): - self.client_left = fn - - def set_fn_message_received(self, fn): - self.message_received = fn - - def send_message(self, client, msg): - self._unicast_(client, msg) - - def send_message_to_all(self, msg): - self._multicast_(msg) - - -# ------------------------- Implementation ----------------------------- - -class WebsocketServer(ThreadingMixIn, TCPServer, API): - """ - A websocket server waiting for clients to connect. - - Args: - port(int): Port to bind to - host(str): Hostname or IP to listen for connections. By default 127.0.0.1 - is being used. To accept connections from any client, you should use - 0.0.0.0. - loglevel: Logging level from logging module to use for logging. By default - warnings and errors are being logged. - - Properties: - clients(list): A list of connected clients. A client is a dictionary - like below. - { - 'id' : id, - 'handler' : handler, - 'address' : (addr, port) - } - """ - - allow_reuse_address = True - daemon_threads = True # comment to keep threads alive until finished - - clients = [] - id_counter = 0 - - def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING): - logger.setLevel(loglevel) - TCPServer.__init__(self, (host, port), WebSocketHandler) - self.port = self.socket.getsockname()[1] - - def _message_received_(self, handler, msg): - self.message_received(self.handler_to_client(handler), self, msg) - - def _ping_received_(self, handler, msg): - handler.send_pong(msg) - - def _pong_received_(self, handler, msg): - pass - - def _new_client_(self, handler): - self.id_counter += 1 - client = { - 'id': self.id_counter, - 'handler': handler, - 'address': handler.client_address - } - self.clients.append(client) - self.new_client(client, self) - - def _client_left_(self, handler): - client = self.handler_to_client(handler) - self.client_left(client, self) - if client in self.clients: - self.clients.remove(client) - - def _unicast_(self, to_client, msg): - to_client['handler'].send_message(msg) - - def _multicast_(self, msg): - for client in self.clients: - self._unicast_(client, msg) - - def handler_to_client(self, handler): - for client in self.clients: - if client['handler'] == handler: - return client - - class WebSocketHandler(StreamRequestHandler): - def __init__(self, socket, addr, server): self.server = server StreamRequestHandler.__init__(self, socket, addr, server) @@ -173,12 +85,20 @@ def setup(self): self.handshake_done = False self.valid_client = False + def message_stream(self): + while self.keep_alive and self.handshake_done and self.valid_client: + yield self.read_next_message() + def handle(self): while self.keep_alive: if not self.handshake_done: - self.handshake() - elif self.valid_client: - self.read_next_message() + path = self.handshake() + + if path not in ROUTING_TABLE: + logger.warning('Bad path for websocket request: %s' % path) + return + route_fn = ROUTING_TABLE[path] + route_fn(iter(self.message_stream()), self.send_message) def read_bytes(self, num): # python3 gives ordinal of byte directly @@ -193,7 +113,7 @@ def read_next_message(self): b1, b2 = self.read_bytes(2) except SocketError as e: # to be replaced with ConnectionResetError for py3 if e.errno == errno.ECONNRESET: - logger.info("Client closed connection.") + logger.debug("Client closed connection.") self.keep_alive = 0 return b1, b2 = 0, 0 @@ -206,7 +126,7 @@ def read_next_message(self): payload_length = b2 & PAYLOAD_LEN if opcode == OPCODE_CLOSE_CONN: - logger.info("Client asked to close connection.") + logger.debug("Client asked to close connection.") self.keep_alive = 0 return if not masked: @@ -220,11 +140,11 @@ def read_next_message(self): logger.warn("Binary frames are not supported.") return elif opcode == OPCODE_TEXT: - opcode_handler = self.server._message_received_ + opcode_type = 'text' elif opcode == OPCODE_PING: - opcode_handler = self.server._ping_received_ + opcode_type = 'ping' elif opcode == OPCODE_PONG: - opcode_handler = self.server._pong_received_ + opcode_type = 'pong' else: logger.warn("Unknown opcode %#x." % opcode) self.keep_alive = 0 @@ -240,7 +160,7 @@ def read_next_message(self): for message_byte in self.read_bytes(payload_length): message_byte ^= masks[len(message_bytes) % 4] message_bytes.append(message_byte) - opcode_handler(self, message_bytes.decode('utf8')) + return (opcode_type, message_bytes.decode('utf8')) def send_message(self, message): self.send_text(message) @@ -298,8 +218,9 @@ def send_text(self, message, opcode=OPCODE_TEXT): def read_http_headers(self): headers = {} # first line should be HTTP GET - http_get = self.rfile.readline().decode().strip() - assert http_get.upper().startswith('GET') + request = self.rfile.readline().decode().strip() + method, path, protocol = request.split() + assert method.startswith('GET') # remaining should be headers while True: header = self.rfile.readline().decode().strip() @@ -307,10 +228,10 @@ def read_http_headers(self): break head, value = header.split(':', 1) headers[head.lower().strip()] = value.strip() - return headers + return path, headers def handshake(self): - headers = self.read_http_headers() + path, headers = self.read_http_headers() try: assert headers['upgrade'].lower() == 'websocket' @@ -328,7 +249,7 @@ def handshake(self): response = self.make_handshake_response(key) self.handshake_done = self.request.send(response.encode()) self.valid_client = True - self.server._new_client_(self) + return path @classmethod def make_handshake_response(cls, key): @@ -346,25 +267,12 @@ def calculate_response_key(cls, key): response_key = b64encode(hash.digest()).strip() return response_key.decode('ASCII') - def finish(self): - self.server._client_left_(self) +def run_websocket_server(): + server = TCPServer(('127.0.0.1', 5001), WebSocketHandler) + server.timeout = 3 + server.allow_reuse_address = True + server.daemon_threads = True + server.serve_forever() - -def encode_to_UTF8(data): - try: - return data.encode('UTF-8') - except UnicodeEncodeError as e: - logger.error("Could not encode data to UTF-8 -- %s" % e) - return False - except Exception as e: - raise(e) - return False - - -def try_decode_UTF8(data): - try: - return data.decode('utf-8') - except UnicodeDecodeError: - return False - except Exception as e: - raise(e) +def start_websocket_server(): + threading.Thread(target=run_websocket_server, daemon=True).start() From 04dd4670f73575d12fa30ec25408cba6b7d892d1 Mon Sep 17 00:00:00 2001 From: Zach Wegner Date: Wed, 17 Jul 2019 14:09:32 -0500 Subject: [PATCH 2/7] Make sure to catch exceptions in the websocket handler --- websocket_server/websocket_server.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index a423685..fa68c95 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -90,15 +90,18 @@ def message_stream(self): yield self.read_next_message() def handle(self): - while self.keep_alive: - if not self.handshake_done: - path = self.handshake() - - if path not in ROUTING_TABLE: - logger.warning('Bad path for websocket request: %s' % path) - return - route_fn = ROUTING_TABLE[path] - route_fn(iter(self.message_stream()), self.send_message) + try: + while self.keep_alive: + if not self.handshake_done: + path = self.handshake() + + if path not in ROUTING_TABLE: + logger.warning('Bad path for websocket request: %s' % path) + return + route_fn = ROUTING_TABLE[path] + route_fn(iter(self.message_stream()), self.send_message) + except BrokenPipeError as e: + logger.info("Client connection broken.") def read_bytes(self, num): # python3 gives ordinal of byte directly From 9dd1e4cc1e0fb913c1ba40c6d3a4154648051ac1 Mon Sep 17 00:00:00 2001 From: Zach Wegner Date: Tue, 23 Jul 2019 14:21:45 -0500 Subject: [PATCH 3/7] Improve websocket interface to let each route handle sending and receiving, the message_stream() iterable isn't a good interface. Also allow pings in the handler loop, and clean up a couple small things --- websocket_server/websocket_server.py | 44 +++++++++++++--------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index fa68c95..f9b898c 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -44,21 +44,19 @@ def decorate(fn): return fn return decorate -''' -+-+-+-+-+-------+-+-------------+-------------------------------+ - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -+-+-+-+-+-------+-+-------------+-------------------------------+ -|F|R|R|R| opcode|M| Payload len | Extended payload length | -|I|S|S|S| (4) |A| (7) | (16/64) | -|N|V|V|V| |S| | (if payload len==126/127) | -| |1|2|3| |K| | | -+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + -| Extended payload length continued, if payload len == 127 | -+ - - - - - - - - - - - - - - - +-------------------------------+ -| Payload Data continued ... | -+---------------------------------------------------------------+ -''' +# +-+-+-+-+-------+-+-------------+-------------------------------+ +# |0| | | | | | 1 | 2 3 | +# |0|1|2|3|4 5 6 7|8|9 0 1 2 3 4 5|6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1| +# +-+-+-+-+-------+-+-------------+-------------------------------+ +# |F|R|R|R| opcode|M| Payload len | Extended payload length | +# |I|S|S|S| (4) |A| (7) | (16/64) | +# |N|V|V|V| |S| | (if payload len==126/127) | +# | |1|2|3| |K| | | +# +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +# | Extended payload length continued, if payload len == 127 | +# + - - - - - - - - - - - - - - - +-------------------------------+ +# | Payload Data continued ... | +# +---------------------------------------------------------------+ FIN = 0x80 OPCODE = 0x0f @@ -85,10 +83,6 @@ def setup(self): self.handshake_done = False self.valid_client = False - def message_stream(self): - while self.keep_alive and self.handshake_done and self.valid_client: - yield self.read_next_message() - def handle(self): try: while self.keep_alive: @@ -99,7 +93,9 @@ def handle(self): logger.warning('Bad path for websocket request: %s' % path) return route_fn = ROUTING_TABLE[path] - route_fn(iter(self.message_stream()), self.send_message) + logger.info('websocket client connected to %s', path) + route_fn(self) + logger.info('websocket client disconnected from %s', path) except BrokenPipeError as e: logger.info("Client connection broken.") @@ -168,8 +164,11 @@ def read_next_message(self): def send_message(self, message): self.send_text(message) + def send_ping(self, message): + self.send_text(message, opcode=OPCODE_PING) + def send_pong(self, message): - self.send_text(message, OPCODE_PONG) + self.send_text(message, opcode=OPCODE_PONG) def send_text(self, message, opcode=OPCODE_TEXT): """ @@ -271,9 +270,8 @@ def calculate_response_key(cls, key): return response_key.decode('ASCII') def run_websocket_server(): + TCPServer.allow_reuse_address = True server = TCPServer(('127.0.0.1', 5001), WebSocketHandler) - server.timeout = 3 - server.allow_reuse_address = True server.daemon_threads = True server.serve_forever() From bde23bf125848c6a9717ba7361032d0aba0a9291 Mon Sep 17 00:00:00 2001 From: Zach Wegner Date: Thu, 15 Aug 2019 17:00:18 -0500 Subject: [PATCH 4/7] Bugfix: we shouldn't be looping while keep_alive is set, we're only serving one request --- websocket_server/websocket_server.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index f9b898c..518bfb3 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -85,17 +85,16 @@ def setup(self): def handle(self): try: - while self.keep_alive: - if not self.handshake_done: - path = self.handshake() - - if path not in ROUTING_TABLE: - logger.warning('Bad path for websocket request: %s' % path) - return - route_fn = ROUTING_TABLE[path] - logger.info('websocket client connected to %s', path) - route_fn(self) - logger.info('websocket client disconnected from %s', path) + if not self.handshake_done: + path = self.handshake() + + if path not in self.server.routing_table: + logger.warning('Bad path for websocket request: %s' % path) + return + route_fn = self.server.routing_table[path] + logger.info('websocket client connected to %s', path) + route_fn(self) + logger.info('websocket client disconnected from %s', path) except BrokenPipeError as e: logger.info("Client connection broken.") @@ -108,6 +107,8 @@ def read_bytes(self, num): return bytes def read_next_message(self): + if not self.keep_alive: + return None try: b1, b2 = self.read_bytes(2) except SocketError as e: # to be replaced with ConnectionResetError for py3 From 96b948d657b0ac5f82b92c6e141e94e944c55a62 Mon Sep 17 00:00:00 2001 From: Zach Wegner Date: Thu, 15 Aug 2019 17:06:29 -0500 Subject: [PATCH 5/7] Wrap all serving code into a class, and make sure we're using the ThreadingMixIn --- websocket_server/websocket_server.py | 37 +++++++++++++++------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 518bfb3..ebae6c7 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -35,15 +35,6 @@ def try_decode_UTF8(data): except Exception as e: raise(e) -ROUTING_TABLE = {} - -# Decorator for routing websocket handlers a la Flask -def route(path): - def decorate(fn): - ROUTING_TABLE[path] = fn - return fn - return decorate - # +-+-+-+-+-------+-+-------------+-------------------------------+ # |0| | | | | | 1 | 2 3 | # |0|1|2|3|4 5 6 7|8|9 0 1 2 3 4 5|6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1| @@ -270,11 +261,23 @@ def calculate_response_key(cls, key): response_key = b64encode(hash.digest()).strip() return response_key.decode('ASCII') -def run_websocket_server(): - TCPServer.allow_reuse_address = True - server = TCPServer(('127.0.0.1', 5001), WebSocketHandler) - server.daemon_threads = True - server.serve_forever() - -def start_websocket_server(): - threading.Thread(target=run_websocket_server, daemon=True).start() +class WebSocketServer(ThreadingMixIn, TCPServer): + allow_reuse_address = True + daemon_threads = True + + def __init__(self, host, port): + address = (host, port) + # Have to init in both superclasses? Weird... + super(TCPServer, self).__init__(address, WebSocketHandler) + super(ThreadingMixIn, self).__init__(address, WebSocketHandler) + self.routing_table = {} + + # Decorator for routing websocket handlers a la Flask + def route(self, path): + def decorate(fn): + self.routing_table[path] = fn + return fn + return decorate + + def start(self): + threading.Thread(target=self.serve_forever, daemon=True).start() From d59c63bef672203f4b62ed9bce9c5ccb6fddb801 Mon Sep 17 00:00:00 2001 From: Zach Wegner Date: Thu, 15 Aug 2019 17:22:11 -0500 Subject: [PATCH 6/7] Change example server to match new API --- client.html | 2 +- server.py | 37 ++++++++++++++++--------------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/client.html b/client.html index 598a6a6..04f8e38 100644 --- a/client.html +++ b/client.html @@ -9,7 +9,7 @@ function init() { // Connect to Web Socket - ws = new WebSocket("ws://localhost:9001/"); + ws = new WebSocket("ws://127.0.0.1:5005/ws/echo"); // Set event handlers. ws.onopen = function() { diff --git a/server.py b/server.py index f0587c6..9672d5c 100644 --- a/server.py +++ b/server.py @@ -1,26 +1,21 @@ -from websocket_server import WebsocketServer +import logging -# Called for every client connecting (after handshake) -def new_client(client, server): - print("New client connected and was given id %d" % client['id']) - server.send_message_to_all("Hey all, a new client has joined us") +import websocket_server +logging.basicConfig(level=logging.INFO) -# Called for every client disconnecting -def client_left(client, server): - print("Client(%d) disconnected" % client['id']) +wss = websocket_server.WebSocketServer('127.0.0.1', 5005) +@wss.route('/ws/echo') +def handle_ws_echo(handler): + seq = 0 + while True: + msg = handler.read_next_message() + if not msg: + break + msg_type, msg_text = msg + msg = 'echo %s: %s' % (seq, msg_text) + seq += 1 + handler.send_message(msg) -# Called when a client sends a message -def message_received(client, server, message): - if len(message) > 200: - message = message[:200]+'..' - print("Client(%d) said: %s" % (client['id'], message)) - - -PORT=9001 -server = WebsocketServer(PORT) -server.set_fn_new_client(new_client) -server.set_fn_client_left(client_left) -server.set_fn_message_received(message_received) -server.run_forever() +wss.serve_forever() From 014a962234577902a5e52531e791953248219bea Mon Sep 17 00:00:00 2001 From: Zach Wegner Date: Thu, 5 Sep 2019 19:09:56 -0500 Subject: [PATCH 7/7] Move websocket routing functionality to a WebSocketRouter class, to allow multiple modules to configure their own routes that can be added to the server later (a la Flask Blueprints) --- websocket_server/websocket_server.py | 30 ++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index ebae6c7..9e31e02 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -261,23 +261,37 @@ def calculate_response_key(cls, key): response_key = b64encode(hash.digest()).strip() return response_key.decode('ASCII') +class WebSocketRouter: + def __init__(self, root_path=''): + self.root_path = root_path + self.routing_table = {} + + # Decorator for routing websocket handlers a la Flask + def route(self, path): + def decorate(fn): + self.routing_table['%s%s' % (self.root_path, path)] = fn + return fn + return decorate + + def add_router(self, router): + for path, fn in router.routing_table.items(): + self.routing_table['%s%s' % (self.root_path, path)] = fn + class WebSocketServer(ThreadingMixIn, TCPServer): allow_reuse_address = True daemon_threads = True - def __init__(self, host, port): + def __init__(self, host, port, root_path=''): address = (host, port) # Have to init in both superclasses? Weird... super(TCPServer, self).__init__(address, WebSocketHandler) super(ThreadingMixIn, self).__init__(address, WebSocketHandler) - self.routing_table = {} + self.router = WebSocketRouter(root_path=root_path) + # HACK-ish + self.routing_table = self.router.routing_table - # Decorator for routing websocket handlers a la Flask - def route(self, path): - def decorate(fn): - self.routing_table[path] = fn - return fn - return decorate + def add_router(self, router): + self.router.add_router(router) def start(self): threading.Thread(target=self.serve_forever, daemon=True).start()