diff --git a/server.py b/server.py index 210c884..c9a0634 100644 --- a/server.py +++ b/server.py @@ -1,26 +1,55 @@ from websocket_server import WebsocketServer +import sys +import threading +import time +import os +import fcntl +import select + + +sshim_client = None # Called for every client connecting (after handshake) def new_client(client, server): - print("New client connected and was given id %d" % client['id']) - server.send_message_to_all("Hey all, a new client has joined us") - + global sshim_client + if (sshim_client == None): + sshim_client = client # Called for every client disconnecting def client_left(client, server): - print("Client(%d) disconnected" % client['id']) - + global sshim_client + sshim_client = None # Called when a client sends a message def message_received(client, server, message): - if len(message) > 200: - message = message[:200]+'..' - print("Client(%d) said: %s" % (client['id'], message)) + sys.stdout.buffer.write(message) + sys.stdout.buffer.flush() +def thread_input(server): + global sshim_client + orig_fl = fcntl.fcntl(sys.stdin, fcntl.F_GETFL) + fcntl.fcntl(sys.stdin, fcntl.F_SETFL, orig_fl | os.O_NONBLOCK) + while True: + if sshim_client != None: + break + else: + time.sleep(4) + while True: + i = [sys.stdin] + ins, _, _ = select.select(i, [], [], 0) + if len(ins) != 0: + data = sys.stdin.buffer.read() + server.send_message_to_all(data) -PORT=9001 + +PORT=int(sys.argv[1]) server = WebsocketServer(port = PORT) server.set_fn_new_client(new_client) server.set_fn_client_left(client_left) server.set_fn_message_received(message_received) + + +x = threading.Thread(target=thread_input, args=(server,)) +x.start() + server.run_forever() diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 083ee17..68be40f 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -315,8 +315,7 @@ def read_next_message(self): logger.warning("Continuation frames are not supported.") return elif opcode == OPCODE_BINARY: - logger.warning("Binary frames are not supported.") - return + opcode_handler = self.server._message_received_ elif opcode == OPCODE_TEXT: opcode_handler = self.server._message_received_ elif opcode == OPCODE_PING: @@ -338,7 +337,7 @@ def read_next_message(self): for message_byte in self.read_bytes(payload_length): message_byte ^= masks[len(message_bytes) % 4] message_bytes.append(message_byte) - opcode_handler(self, message_bytes.decode('utf8')) + opcode_handler(self, message_bytes) def send_message(self, message): self.send_text(message) @@ -375,17 +374,17 @@ def send_text(self, message, opcode=OPCODE_TEXT): """ # Validate message - if isinstance(message, bytes): - message = try_decode_UTF8(message) # this is slower but ensures we have UTF-8 - if not message: - logger.warning("Can\'t send message, message is not valid UTF-8") - return False - elif not isinstance(message, str): - logger.warning('Can\'t send message, message has to be a string or bytes. Got %s' % type(message)) - return False + #if isinstance(message, bytes): + #message = try_decode_UTF8(message) # this is slower but ensures we have UTF-8 + #if not message: + # logger.warning("Can\'t send message, message is not valid UTF-8") + # return False + #elif not isinstance(message, str): + # logger.warning('Can\'t send message, message has to be a string or bytes. Got %s' % type(message)) + # return False header = bytearray() - payload = encode_to_UTF8(message) + payload = message payload_length = len(payload) # Normal payload