diff --git a/client.html b/client.html index 598a6a6..04f8e38 100644 --- a/client.html +++ b/client.html @@ -9,7 +9,7 @@ function init() { // Connect to Web Socket - ws = new WebSocket("ws://localhost:9001/"); + ws = new WebSocket("ws://127.0.0.1:5005/ws/echo"); // Set event handlers. ws.onopen = function() { diff --git a/server.py b/server.py index f0587c6..9672d5c 100644 --- a/server.py +++ b/server.py @@ -1,26 +1,21 @@ -from websocket_server import WebsocketServer +import logging -# Called for every client connecting (after handshake) -def new_client(client, server): - print("New client connected and was given id %d" % client['id']) - server.send_message_to_all("Hey all, a new client has joined us") +import websocket_server +logging.basicConfig(level=logging.INFO) -# Called for every client disconnecting -def client_left(client, server): - print("Client(%d) disconnected" % client['id']) +wss = websocket_server.WebSocketServer('127.0.0.1', 5005) +@wss.route('/ws/echo') +def handle_ws_echo(handler): + seq = 0 + while True: + msg = handler.read_next_message() + if not msg: + break + msg_type, msg_text = msg + msg = 'echo %s: %s' % (seq, msg_text) + seq += 1 + handler.send_message(msg) -# Called when a client sends a message -def message_received(client, server, message): - if len(message) > 200: - message = message[:200]+'..' - print("Client(%d) said: %s" % (client['id'], message)) - - -PORT=9001 -server = WebsocketServer(PORT) -server.set_fn_new_client(new_client) -server.set_fn_client_left(client_left) -server.set_fn_message_received(message_received) -server.run_forever() +wss.serve_forever() diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 96c658b..9e31e02 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -1,13 +1,14 @@ # Author: Johan Hanssen Seferidis # License: MIT +import errno +import logging import sys import struct +import threading from base64 import b64encode from hashlib import sha1 -import logging from socket import error as SocketError -import errno if sys.version_info[0] < 3: from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler @@ -15,23 +16,38 @@ from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler logger = logging.getLogger(__name__) -logging.basicConfig() - -''' -+-+-+-+-+-------+-+-------------+-------------------------------+ - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -+-+-+-+-+-------+-+-------------+-------------------------------+ -|F|R|R|R| opcode|M| Payload len | Extended payload length | -|I|S|S|S| (4) |A| (7) | (16/64) | -|N|V|V|V| |S| | (if payload len==126/127) | -| |1|2|3| |K| | | -+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + -| Extended payload length continued, if payload len == 127 | -+ - - - - - - - - - - - - - - - +-------------------------------+ -| Payload Data continued ... | -+---------------------------------------------------------------+ -''' + +def encode_to_UTF8(data): + try: + return data.encode('UTF-8') + except UnicodeEncodeError as e: + logger.error("Could not encode data to UTF-8 -- %s" % e) + return False + except Exception as e: + raise(e) + return False + +def try_decode_UTF8(data): + try: + return data.decode('utf-8') + except UnicodeDecodeError: + return False + except Exception as e: + raise(e) + +# +-+-+-+-+-------+-+-------------+-------------------------------+ +# |0| | | | | | 1 | 2 3 | +# |0|1|2|3|4 5 6 7|8|9 0 1 2 3 4 5|6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1| +# +-+-+-+-+-------+-+-------------+-------------------------------+ +# |F|R|R|R| opcode|M| Payload len | Extended payload length | +# |I|S|S|S| (4) |A| (7) | (16/64) | +# |N|V|V|V| |S| | (if payload len==126/127) | +# | |1|2|3| |K| | | +# +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +# | Extended payload length continued, if payload len == 127 | +# + - - - - - - - - - - - - - - - +-------------------------------+ +# | Payload Data continued ... | +# +---------------------------------------------------------------+ FIN = 0x80 OPCODE = 0x0f @@ -47,122 +63,7 @@ OPCODE_PING = 0x9 OPCODE_PONG = 0xA - -# -------------------------------- API --------------------------------- - -class API(): - - def run_forever(self): - try: - logger.info("Listening on port %d for clients.." % self.port) - self.serve_forever() - except KeyboardInterrupt: - self.server_close() - logger.info("Server terminated.") - except Exception as e: - logger.error(str(e), exc_info=True) - exit(1) - - def new_client(self, client, server): - pass - - def client_left(self, client, server): - pass - - def message_received(self, client, server, message): - pass - - def set_fn_new_client(self, fn): - self.new_client = fn - - def set_fn_client_left(self, fn): - self.client_left = fn - - def set_fn_message_received(self, fn): - self.message_received = fn - - def send_message(self, client, msg): - self._unicast_(client, msg) - - def send_message_to_all(self, msg): - self._multicast_(msg) - - -# ------------------------- Implementation ----------------------------- - -class WebsocketServer(ThreadingMixIn, TCPServer, API): - """ - A websocket server waiting for clients to connect. - - Args: - port(int): Port to bind to - host(str): Hostname or IP to listen for connections. By default 127.0.0.1 - is being used. To accept connections from any client, you should use - 0.0.0.0. - loglevel: Logging level from logging module to use for logging. By default - warnings and errors are being logged. - - Properties: - clients(list): A list of connected clients. A client is a dictionary - like below. - { - 'id' : id, - 'handler' : handler, - 'address' : (addr, port) - } - """ - - allow_reuse_address = True - daemon_threads = True # comment to keep threads alive until finished - - clients = [] - id_counter = 0 - - def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING): - logger.setLevel(loglevel) - TCPServer.__init__(self, (host, port), WebSocketHandler) - self.port = self.socket.getsockname()[1] - - def _message_received_(self, handler, msg): - self.message_received(self.handler_to_client(handler), self, msg) - - def _ping_received_(self, handler, msg): - handler.send_pong(msg) - - def _pong_received_(self, handler, msg): - pass - - def _new_client_(self, handler): - self.id_counter += 1 - client = { - 'id': self.id_counter, - 'handler': handler, - 'address': handler.client_address - } - self.clients.append(client) - self.new_client(client, self) - - def _client_left_(self, handler): - client = self.handler_to_client(handler) - self.client_left(client, self) - if client in self.clients: - self.clients.remove(client) - - def _unicast_(self, to_client, msg): - to_client['handler'].send_message(msg) - - def _multicast_(self, msg): - for client in self.clients: - self._unicast_(client, msg) - - def handler_to_client(self, handler): - for client in self.clients: - if client['handler'] == handler: - return client - - class WebSocketHandler(StreamRequestHandler): - def __init__(self, socket, addr, server): self.server = server StreamRequestHandler.__init__(self, socket, addr, server) @@ -174,11 +75,19 @@ def setup(self): self.valid_client = False def handle(self): - while self.keep_alive: + try: if not self.handshake_done: - self.handshake() - elif self.valid_client: - self.read_next_message() + path = self.handshake() + + if path not in self.server.routing_table: + logger.warning('Bad path for websocket request: %s' % path) + return + route_fn = self.server.routing_table[path] + logger.info('websocket client connected to %s', path) + route_fn(self) + logger.info('websocket client disconnected from %s', path) + except BrokenPipeError as e: + logger.info("Client connection broken.") def read_bytes(self, num): # python3 gives ordinal of byte directly @@ -189,11 +98,13 @@ def read_bytes(self, num): return bytes def read_next_message(self): + if not self.keep_alive: + return None try: b1, b2 = self.read_bytes(2) except SocketError as e: # to be replaced with ConnectionResetError for py3 if e.errno == errno.ECONNRESET: - logger.info("Client closed connection.") + logger.debug("Client closed connection.") self.keep_alive = 0 return b1, b2 = 0, 0 @@ -206,7 +117,7 @@ def read_next_message(self): payload_length = b2 & PAYLOAD_LEN if opcode == OPCODE_CLOSE_CONN: - logger.info("Client asked to close connection.") + logger.debug("Client asked to close connection.") self.keep_alive = 0 return if not masked: @@ -220,11 +131,11 @@ def read_next_message(self): logger.warn("Binary frames are not supported.") return elif opcode == OPCODE_TEXT: - opcode_handler = self.server._message_received_ + opcode_type = 'text' elif opcode == OPCODE_PING: - opcode_handler = self.server._ping_received_ + opcode_type = 'ping' elif opcode == OPCODE_PONG: - opcode_handler = self.server._pong_received_ + opcode_type = 'pong' else: logger.warn("Unknown opcode %#x." % opcode) self.keep_alive = 0 @@ -240,13 +151,16 @@ def read_next_message(self): for message_byte in self.read_bytes(payload_length): message_byte ^= masks[len(message_bytes) % 4] message_bytes.append(message_byte) - opcode_handler(self, message_bytes.decode('utf8')) + return (opcode_type, message_bytes.decode('utf8')) def send_message(self, message): self.send_text(message) + def send_ping(self, message): + self.send_text(message, opcode=OPCODE_PING) + def send_pong(self, message): - self.send_text(message, OPCODE_PONG) + self.send_text(message, opcode=OPCODE_PONG) def send_text(self, message, opcode=OPCODE_TEXT): """ @@ -298,8 +212,9 @@ def send_text(self, message, opcode=OPCODE_TEXT): def read_http_headers(self): headers = {} # first line should be HTTP GET - http_get = self.rfile.readline().decode().strip() - assert http_get.upper().startswith('GET') + request = self.rfile.readline().decode().strip() + method, path, protocol = request.split() + assert method.startswith('GET') # remaining should be headers while True: header = self.rfile.readline().decode().strip() @@ -307,10 +222,10 @@ def read_http_headers(self): break head, value = header.split(':', 1) headers[head.lower().strip()] = value.strip() - return headers + return path, headers def handshake(self): - headers = self.read_http_headers() + path, headers = self.read_http_headers() try: assert headers['upgrade'].lower() == 'websocket' @@ -328,7 +243,7 @@ def handshake(self): response = self.make_handshake_response(key) self.handshake_done = self.request.send(response.encode()) self.valid_client = True - self.server._new_client_(self) + return path @classmethod def make_handshake_response(cls, key): @@ -346,25 +261,37 @@ def calculate_response_key(cls, key): response_key = b64encode(hash.digest()).strip() return response_key.decode('ASCII') - def finish(self): - self.server._client_left_(self) +class WebSocketRouter: + def __init__(self, root_path=''): + self.root_path = root_path + self.routing_table = {} + # Decorator for routing websocket handlers a la Flask + def route(self, path): + def decorate(fn): + self.routing_table['%s%s' % (self.root_path, path)] = fn + return fn + return decorate -def encode_to_UTF8(data): - try: - return data.encode('UTF-8') - except UnicodeEncodeError as e: - logger.error("Could not encode data to UTF-8 -- %s" % e) - return False - except Exception as e: - raise(e) - return False + def add_router(self, router): + for path, fn in router.routing_table.items(): + self.routing_table['%s%s' % (self.root_path, path)] = fn - -def try_decode_UTF8(data): - try: - return data.decode('utf-8') - except UnicodeDecodeError: - return False - except Exception as e: - raise(e) +class WebSocketServer(ThreadingMixIn, TCPServer): + allow_reuse_address = True + daemon_threads = True + + def __init__(self, host, port, root_path=''): + address = (host, port) + # Have to init in both superclasses? Weird... + super(TCPServer, self).__init__(address, WebSocketHandler) + super(ThreadingMixIn, self).__init__(address, WebSocketHandler) + self.router = WebSocketRouter(root_path=root_path) + # HACK-ish + self.routing_table = self.router.routing_table + + def add_router(self, router): + self.router.add_router(router) + + def start(self): + threading.Thread(target=self.serve_forever, daemon=True).start()