diff --git a/websocket_server/websocket_server.py b/websocket_server/websocket_server.py index 9d1af5c..7f81fe8 100644 --- a/websocket_server/websocket_server.py +++ b/websocket_server/websocket_server.py @@ -172,6 +172,8 @@ def setup(self): self.keep_alive = True self.handshake_done = False self.valid_client = False + self.fragment_opcode = 0 + self.fragment_payload_buf = bytearray() def handle(self): while self.keep_alive: @@ -214,22 +216,9 @@ def read_next_message(self): logger.warn("Client must always be masked.") self.keep_alive = 0 return - if opcode == OPCODE_CONTINUATION: - logger.warn("Continuation frames are not supported.") - return - elif opcode == OPCODE_BINARY: + if opcode == OPCODE_BINARY: logger.warn("Binary frames are not supported.") return - elif opcode == OPCODE_TEXT: - opcode_handler = self.server._message_received_ - elif opcode == OPCODE_PING: - opcode_handler = self.server._ping_received_ - elif opcode == OPCODE_PONG: - opcode_handler = self.server._pong_received_ - else: - logger.warn("Unknown opcode %#x." % opcode) - self.keep_alive = 0 - return if payload_length == 126: payload_length = struct.unpack(">H", self.rfile.read(2))[0] @@ -237,11 +226,36 @@ def read_next_message(self): payload_length = struct.unpack(">Q", self.rfile.read(8))[0] masks = self.read_bytes(4) - message_bytes = bytearray() + payload = bytearray() for message_byte in self.read_bytes(payload_length): - message_byte ^= masks[len(message_bytes) % 4] - message_bytes.append(message_byte) - opcode_handler(self, message_bytes.decode('utf8')) + message_byte ^= masks[len(payload) % 4] + payload.append(message_byte) + + if fin and opcode != OPCODE_CONTINUATION: # simple msg + if opcode == OPCODE_PING: + self.server._ping_received_(self, payload.decode('utf8')) + elif opcode == OPCODE_PONG: + self.server._pong_received_(self, payload.decode('utf8')) + elif opcode == OPCODE_TEXT: + self.server._message_received_(self, payload.decode('utf8')) + return + + if not fin and opcode: # fragment msg start + self.fragment_opcode = opcode + self.fragment_payload_buf = payload + return + + # "not opcode" is the same as "opcode == OPCODE_CONTINUATION" + if not fin and not opcode: # fragment msg ing + self.fragment_payload_buf.extend(payload) + return + + if fin and opcode == OPCODE_CONTINUATION: # fragment msg end + if self.fragment_opcode == OPCODE_TEXT: + self.server._message_received_(self, (self.fragment_payload_buf + payload).decode('utf8')) + elif self.fragment_opcode == OPCODE_BINARY: + pass + return def send_message(self, message): self.send_text(message)