diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 9d1af5c..7eca370 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -8,6 +8,7 @@ import logging from socket import error as SocketError import errno +from threading import Thread if sys.version_info[0] < 3: from SocketServer import ThreadingMixIn, TCPServer, StreamRequestHandler @@ -50,7 +51,12 @@ # -------------------------------- API --------------------------------- -class API(): +class API(Thread): + def __init__(self): + Thread.__init__(self) + + def run(self): + self.run_forever() def run_forever(self): try: @@ -61,7 +67,6 @@ def run_forever(self): logger.info("Server terminated.") except Exception as e: logger.error(str(e), exc_info=True) - exit(1) def new_client(self, client, server): pass @@ -119,6 +124,7 @@ class WebsocketServer(ThreadingMixIn, TCPServer, API): id_counter = 0 def __init__(self, port, host='127.0.0.1', loglevel=logging.WARNING): + API.__init__(self) logger.setLevel(loglevel) TCPServer.__init__(self, (host, port), WebSocketHandler) self.port = self.socket.getsockname()[1] @@ -313,6 +319,11 @@ def read_http_headers(self): def handshake(self): headers = self.read_http_headers() + if 'upgrade' not in headers.keys(): + self.request.send(self.make_upgrade_response().encode()) + self.keep_alive = False + return + try: assert headers['upgrade'].lower() == 'websocket' except AssertionError: @@ -331,6 +342,17 @@ def handshake(self): self.valid_client = True self.server._new_client_(self) + @classmethod + def make_upgrade_response(cls): + return \ + 'HTTP/1.1 426 Upgrade Required\r\n'\ + 'Upgrade: HTTP/3.0\r\n'\ + 'Connection: Close\r\n'\ + 'Content-Encoding: text/plain\r\n'\ + '\r\n'\ + 'This service requires use of the HTTP/3.0 protocol' + + @classmethod def make_handshake_response(cls, key): return \