diff --git a/client.html b/client.html
index 598a6a6..04f8e38 100644
--- a/client.html
+++ b/client.html
@@ -9,7 +9,7 @@
function init() {
// Connect to Web Socket
- ws = new WebSocket("ws://localhost:9001/");
+ ws = new WebSocket("ws://127.0.0.1:5005/ws/echo");
// Set event handlers.
ws.onopen = function() {
diff --git a/server.py b/server.py
index f0587c6..9672d5c 100644
--- a/server.py
+++ b/server.py
@@ -1,26 +1,21 @@
-from websocket_server import WebsocketServer
+import logging
-# Called for every client connecting (after handshake)
-def new_client(client, server):
- print("New client connected and was given id %d" % client['id'])
- server.send_message_to_all("Hey all, a new client has joined us")
+import websocket_server
+logging.basicConfig(level=logging.INFO)
-# Called for every client disconnecting
-def client_left(client, server):
- print("Client(%d) disconnected" % client['id'])
+wss = websocket_server.WebSocketServer('127.0.0.1', 5005)
+@wss.route('/ws/echo')
+def handle_ws_echo(handler):
+ seq = 0
+ while True:
+ msg = handler.read_next_message()
+ if not msg:
+ break
+ msg_type, msg_text = msg
+ msg = 'echo %s: %s' % (seq, msg_text)
+ seq += 1
+ handler.send_message(msg)
-# Called when a client sends a message
-def message_received(client, server, message):
- if len(message) > 200:
- message = message[:200]+'..'
- print("Client(%d) said: %s" % (client['id'], message))
-
-
-PORT=9001
-server = WebsocketServer(PORT)
-server.set_fn_new_client(new_client)
-server.set_fn_client_left(client_left)
-server.set_fn_message_received(message_received)
-server.run_forever()
+wss.serve_forever()
diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py
index 96c658b..9e31e02 100644
--- a/websocket_server/websocket_server.py
+++ b/websocket_server/websocket_server.py
@@ -1,13 +1,14 @@
# Author: Johan Hanssen Seferidis
# License: MIT
+import errno
+import logging
import sys
import struct
+import threading
from base64 import b64encode
from hashlib import sha1
-import logging
from socket import error as SocketError
-import errno
if sys.version_info[0] < 3:
from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler
@@ -15,23 +16,38 @@
from socketserver import ThreadingMixIn, TCPServer, StreamRequestHandler
logger = logging.getLogger(__name__)
-logging.basicConfig()
-
-'''
-+-+-+-+-+-------+-+-------------+-------------------------------+
- 0 1 2 3
- 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
-+-+-+-+-+-------+-+-------------+-------------------------------+
-|F|R|R|R| opcode|M| Payload len | Extended payload length |
-|I|S|S|S| (4) |A| (7) | (16/64) |
-|N|V|V|V| |S| | (if payload len==126/127) |
-| |1|2|3| |K| | |
-+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
-| Extended payload length continued, if payload len == 127 |
-+ - - - - - - - - - - - - - - - +-------------------------------+
-| Payload Data continued ... |
-+---------------------------------------------------------------+
-'''
+
+def encode_to_UTF8(data):
+ try:
+ return data.encode('UTF-8')
+ except UnicodeEncodeError as e:
+ logger.error("Could not encode data to UTF-8 -- %s" % e)
+ return False
+ except Exception as e:
+ raise(e)
+ return False
+
+def try_decode_UTF8(data):
+ try:
+ return data.decode('utf-8')
+ except UnicodeDecodeError:
+ return False
+ except Exception as e:
+ raise(e)
+
+# +-+-+-+-+-------+-+-------------+-------------------------------+
+# |0| | | | | | 1 | 2 3 |
+# |0|1|2|3|4 5 6 7|8|9 0 1 2 3 4 5|6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1|
+# +-+-+-+-+-------+-+-------------+-------------------------------+
+# |F|R|R|R| opcode|M| Payload len | Extended payload length |
+# |I|S|S|S| (4) |A| (7) | (16/64) |
+# |N|V|V|V| |S| | (if payload len==126/127) |
+# | |1|2|3| |K| | |
+# +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+# | Extended payload length continued, if payload len == 127 |
+# + - - - - - - - - - - - - - - - +-------------------------------+
+# | Payload Data continued ... |
+# +---------------------------------------------------------------+
FIN = 0x80
OPCODE = 0x0f
@@ -47,122 +63,7 @@
OPCODE_PING = 0x9
OPCODE_PONG = 0xA
-
-# -------------------------------- API ---------------------------------
-
-class API():
-
- def run_forever(self):
- try:
- logger.info("Listening on port %d for clients.." % self.port)
- self.serve_forever()
- except KeyboardInterrupt:
- self.server_close()
- logger.info("Server terminated.")
- except Exception as e:
- logger.error(str(e), exc_info=True)
- exit(1)
-
- def new_client(self, client, server):
- pass
-
- def client_left(self, client, server):
- pass
-
- def message_received(self, client, server, message):
- pass
-
- def set_fn_new_client(self, fn):
- self.new_client = fn
-
- def set_fn_client_left(self, fn):
- self.client_left = fn
-
- def set_fn_message_received(self, fn):
- self.message_received = fn
-
- def send_message(self, client, msg):
- self._unicast_(client, msg)
-
- def send_message_to_all(self, msg):
- self._multicast_(msg)
-
-
-# ------------------------- Implementation -----------------------------
-
-class WebsocketServer(ThreadingMixIn, TCPServer, API):
- """
- A websocket server waiting for clients to connect.
-
- Args:
- port(int): Port to bind to
- host(str): Hostname or IP to listen for connections. By default 127.0.0.1
- is being used. To accept connections from any client, you should use
- 0.0.0.0.
- loglevel: Logging level from logging module to use for logging. By default
- warnings and errors are being logged.
-
- Properties:
- clients(list): A list of connected clients. A client is a dictionary
- like below.
- {
- 'id' : id,
- 'handler' : handler,
- 'address' : (addr, port)
- }
- """
-
- allow_reuse_address = True
- daemon_threads = True # comment to keep threads alive until finished
-
- clients = []
- id_counter = 0
-
- def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING):
- logger.setLevel(loglevel)
- TCPServer.__init__(self, (host, port), WebSocketHandler)
- self.port = self.socket.getsockname()[1]
-
- def _message_received_(self, handler, msg):
- self.message_received(self.handler_to_client(handler), self, msg)
-
- def _ping_received_(self, handler, msg):
- handler.send_pong(msg)
-
- def _pong_received_(self, handler, msg):
- pass
-
- def _new_client_(self, handler):
- self.id_counter += 1
- client = {
- 'id': self.id_counter,
- 'handler': handler,
- 'address': handler.client_address
- }
- self.clients.append(client)
- self.new_client(client, self)
-
- def _client_left_(self, handler):
- client = self.handler_to_client(handler)
- self.client_left(client, self)
- if client in self.clients:
- self.clients.remove(client)
-
- def _unicast_(self, to_client, msg):
- to_client['handler'].send_message(msg)
-
- def _multicast_(self, msg):
- for client in self.clients:
- self._unicast_(client, msg)
-
- def handler_to_client(self, handler):
- for client in self.clients:
- if client['handler'] == handler:
- return client
-
-
class WebSocketHandler(StreamRequestHandler):
-
def __init__(self, socket, addr, server):
self.server = server
StreamRequestHandler.__init__(self, socket, addr, server)
@@ -174,11 +75,19 @@ def setup(self):
self.valid_client = False
def handle(self):
- while self.keep_alive:
+ try:
if not self.handshake_done:
- self.handshake()
- elif self.valid_client:
- self.read_next_message()
+ path = self.handshake()
+
+ if path not in self.server.routing_table:
+ logger.warning('Bad path for websocket request: %s' % path)
+ return
+ route_fn = self.server.routing_table[path]
+ logger.info('websocket client connected to %s', path)
+ route_fn(self)
+ logger.info('websocket client disconnected from %s', path)
+ except BrokenPipeError as e:
+ logger.info("Client connection broken.")
def read_bytes(self, num):
# python3 gives ordinal of byte directly
@@ -189,11 +98,13 @@ def read_bytes(self, num):
return bytes
def read_next_message(self):
+ if not self.keep_alive:
+ return None
try:
b1, b2 = self.read_bytes(2)
except SocketError as e: # to be replaced with ConnectionResetError for py3
if e.errno == errno.ECONNRESET:
- logger.info("Client closed connection.")
+ logger.debug("Client closed connection.")
self.keep_alive = 0
return
b1, b2 = 0, 0
@@ -206,7 +117,7 @@ def read_next_message(self):
payload_length = b2 & PAYLOAD_LEN
if opcode == OPCODE_CLOSE_CONN:
- logger.info("Client asked to close connection.")
+ logger.debug("Client asked to close connection.")
self.keep_alive = 0
return
if not masked:
@@ -220,11 +131,11 @@ def read_next_message(self):
logger.warn("Binary frames are not supported.")
return
elif opcode == OPCODE_TEXT:
- opcode_handler = self.server._message_received_
+ opcode_type = 'text'
elif opcode == OPCODE_PING:
- opcode_handler = self.server._ping_received_
+ opcode_type = 'ping'
elif opcode == OPCODE_PONG:
- opcode_handler = self.server._pong_received_
+ opcode_type = 'pong'
else:
logger.warn("Unknown opcode %#x." % opcode)
self.keep_alive = 0
@@ -240,13 +151,16 @@ def read_next_message(self):
for message_byte in self.read_bytes(payload_length):
message_byte ^= masks[len(message_bytes) % 4]
message_bytes.append(message_byte)
- opcode_handler(self, message_bytes.decode('utf8'))
+ return (opcode_type, message_bytes.decode('utf8'))
def send_message(self, message):
self.send_text(message)
+ def send_ping(self, message):
+ self.send_text(message, opcode=OPCODE_PING)
+
def send_pong(self, message):
- self.send_text(message, OPCODE_PONG)
+ self.send_text(message, opcode=OPCODE_PONG)
def send_text(self, message, opcode=OPCODE_TEXT):
"""
@@ -298,8 +212,9 @@ def send_text(self, message, opcode=OPCODE_TEXT):
def read_http_headers(self):
headers = {}
# first line should be HTTP GET
- http_get = self.rfile.readline().decode().strip()
- assert http_get.upper().startswith('GET')
+ request = self.rfile.readline().decode().strip()
+ method, path, protocol = request.split()
+ assert method.startswith('GET')
# remaining should be headers
while True:
header = self.rfile.readline().decode().strip()
@@ -307,10 +222,10 @@ def read_http_headers(self):
break
head, value = header.split(':', 1)
headers[head.lower().strip()] = value.strip()
- return headers
+ return path, headers
def handshake(self):
- headers = self.read_http_headers()
+ path, headers = self.read_http_headers()
try:
assert headers['upgrade'].lower() == 'websocket'
@@ -328,7 +243,7 @@ def handshake(self):
response = self.make_handshake_response(key)
self.handshake_done = self.request.send(response.encode())
self.valid_client = True
- self.server._new_client_(self)
+ return path
@classmethod
def make_handshake_response(cls, key):
@@ -346,25 +261,37 @@ def calculate_response_key(cls, key):
response_key = b64encode(hash.digest()).strip()
return response_key.decode('ASCII')
- def finish(self):
- self.server._client_left_(self)
+class WebSocketRouter:
+ def __init__(self, root_path=''):
+ self.root_path = root_path
+ self.routing_table = {}
+ # Decorator for routing websocket handlers a la Flask
+ def route(self, path):
+ def decorate(fn):
+ self.routing_table['%s%s' % (self.root_path, path)] = fn
+ return fn
+ return decorate
-def encode_to_UTF8(data):
- try:
- return data.encode('UTF-8')
- except UnicodeEncodeError as e:
- logger.error("Could not encode data to UTF-8 -- %s" % e)
- return False
- except Exception as e:
- raise(e)
- return False
+ def add_router(self, router):
+ for path, fn in router.routing_table.items():
+ self.routing_table['%s%s' % (self.root_path, path)] = fn
-
-def try_decode_UTF8(data):
- try:
- return data.decode('utf-8')
- except UnicodeDecodeError:
- return False
- except Exception as e:
- raise(e)
+class WebSocketServer(ThreadingMixIn, TCPServer):
+ allow_reuse_address = True
+ daemon_threads = True
+
+ def __init__(self, host, port, root_path=''):
+ address = (host, port)
+ # Have to init in both superclasses? Weird...
+ super(TCPServer, self).__init__(address, WebSocketHandler)
+ super(ThreadingMixIn, self).__init__(address, WebSocketHandler)
+ self.router = WebSocketRouter(root_path=root_path)
+ # HACK-ish
+ self.routing_table = self.router.routing_table
+
+ def add_router(self, router):
+ self.router.add_router(router)
+
+ def start(self):
+ threading.Thread(target=self.serve_forever, daemon=True).start()