diff --git a/README.md b/README.md index 22d4b9b..c0bbb3d 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ The WebsocketServer can be initialized with the below parameters. | `set_fn_new_client()` | Sets a callback function that will be called for every new `client` connecting to us | function | None | | `set_fn_client_left()` | Sets a callback function that will be called for every `client` disconnecting from us | function | None | | `set_fn_message_received()` | Sets a callback function that will be called when a `client` sends a message | function | None | +| `set_fn_authenticate()` | Sets a callback function that will be called during the handshake | function | None | | `send_message()` | Sends a `message` to a specific `client`. The message is a simple string. | client, message | None | | `send_message_to_all()` | Sends a `message` to **all** connected clients. The message is a simple string. | message | None | @@ -75,6 +76,7 @@ The WebsocketServer can be initialized with the below parameters. | `set_fn_new_client()` | Called for every new `client` connecting to us | client, server | | `set_fn_client_left()` | Called for every `client` disconnecting from us | client, server | | `set_fn_message_received()` | Called when a `client` sends a `message` | client, server, message | +| `set_fn_authenticate()` | Called every time when a handshake is triggered | message | The client passed to the callback is the client that left, sent the message, etc. The server might not have any use to use. However it is passed in case you want to send messages to clients. diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index dd0af82..f8d8097 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -49,7 +49,7 @@ # -------------------------------- API --------------------------------- -class API(): +class API(object): def run_forever(self): try: @@ -71,6 +71,9 @@ def client_left(self, client, server): def message_received(self, client, server, message): pass + def authenticate(self, msg): + return True + def set_fn_new_client(self, fn): self.new_client = fn @@ -80,6 +83,9 @@ def set_fn_client_left(self, fn): def set_fn_message_received(self, fn): self.message_received = fn + def set_fn_authenticate(self, fn): + self.authenticate = fn + def send_message(self, client, msg): self._unicast_(client, msg) @@ -147,6 +153,9 @@ def _client_left_(self, handler): if client in self.clients: self.clients.remove(client) + def _authenticate_(self, msg): + return self.authenticate(msg) + def _unicast_(self, to_client, msg): to_client['handler'].send_message(msg) @@ -160,7 +169,7 @@ def handler_to_client(self, handler): return client -class WebSocketHandler(StreamRequestHandler): +class WebSocketHandler(StreamRequestHandler, object): def __init__(self, socket, addr, server): self.server = server @@ -303,6 +312,11 @@ def handshake(self): logger.warning("Client tried to connect but was missing a key") self.keep_alive = False return + if not self.server._authenticate_(message): + logger.warning("Failed to authenticate a client.") + self.request.send(self.make_unauthorized_response().encode()) + self.keep_alive = False + return response = self.make_handshake_response(key) self.handshake_done = self.request.send(response.encode()) self.valid_client = True @@ -316,6 +330,9 @@ def make_handshake_response(self, key): 'Sec-WebSocket-Accept: %s\r\n' \ '\r\n' % self.calculate_response_key(key) + def make_unauthorized_response(self): + return 'HTTP/1.1 403 Forbidden\r\n\r\n' + def calculate_response_key(self, key): GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' hash = sha1(key.encode() + GUID.encode())