From 7e9f93080ecbbecc6b67172043afab9cd9e4de99 Mon Sep 17 00:00:00 2001 From: Logan McNaughton <848146+loganmc10@users.noreply.github.com> Date: Fri, 5 Sep 2025 09:06:18 +0200 Subject: [PATCH] allow auth during join room (#150) * allow auth during join room * linter --- internal/lobbyServer/lobby.go | 64 ++++++++++++++--------------------- 1 file changed, 25 insertions(+), 39 deletions(-) diff --git a/internal/lobbyServer/lobby.go b/internal/lobbyServer/lobby.go index 30a9c85..a13a36a 100644 --- a/internal/lobbyServer/lobby.go +++ b/internal/lobbyServer/lobby.go @@ -245,15 +245,21 @@ func (s *LobbyServer) watchGameServer(name string, g *gameserver.GameServer) { } } -func (s *LobbyServer) validateAuth(receivedMessage SocketMessage) error { +func (s *LobbyServer) validateAuth(receivedMessage SocketMessage) (int, error) { + if receivedMessage.NetplayVersion != NetplayAPIVersion { + return MismatchVersion, fmt.Errorf("client and server not at same API version, please update your emulator") + } else if receivedMessage.Emulator == "" { + return BadEmulator, fmt.Errorf("emulator name cannot be empty") + } + if !s.EnableAuth { - return nil + return Accepted, nil } now := time.Now().UTC() timeAsInt, err := strconv.ParseInt(receivedMessage.AuthTime, 10, 64) if err != nil { - return fmt.Errorf("could not parse time for authentication") + return BadAuth, fmt.Errorf("could not parse time for authentication") } receivedTime := time.UnixMilli(timeAsInt).UTC() @@ -262,7 +268,7 @@ func (s *LobbyServer) validateAuth(receivedMessage SocketMessage) error { maxAllowableDifference := 15 * time.Minute if absTimeDifference > maxAllowableDifference { - return fmt.Errorf("clock skew detected, please check your system time") + return BadAuth, fmt.Errorf("clock skew detected, please check your system time") } h := sha256.New() @@ -270,14 +276,14 @@ func (s *LobbyServer) validateAuth(receivedMessage SocketMessage) error { authCode := os.Getenv(fmt.Sprintf("%s_AUTH", strings.ToUpper(receivedMessage.Emulator))) if authCode == "" { - return fmt.Errorf("no authentication code found for emulator %s", receivedMessage.Emulator) + return BadAuth, fmt.Errorf("no authentication code found for emulator %s", receivedMessage.Emulator) } h.Write([]byte(authCode)) if receivedMessage.Auth == hex.EncodeToString(h.Sum(nil)) { - return nil + return Accepted, nil } else { - return fmt.Errorf("bad authentication code") + return BadAuth, fmt.Errorf("bad authentication code") } } @@ -340,12 +346,6 @@ func (s *LobbyServer) wsHandler(w http.ResponseWriter, r *http.Request) { if err := s.sendData(ws, sendMessage); err != nil { s.Logger.Error(err, "failed to send message", "message", sendMessage, "address", ws.RemoteAddr()) } - } else if receivedMessage.NetplayVersion != NetplayAPIVersion { - sendMessage.Accept = MismatchVersion - sendMessage.Message = "Client and server not at same API version. Please update your emulator" - if err := s.sendData(ws, sendMessage); err != nil { - s.Logger.Error(err, "failed to send message", "message", sendMessage, "address", ws.RemoteAddr()) - } } else if receivedMessage.Room.RoomName == "" { sendMessage.Accept = BadName sendMessage.Message = "Room name cannot be empty" @@ -358,16 +358,10 @@ func (s *LobbyServer) wsHandler(w http.ResponseWriter, r *http.Request) { if err := s.sendData(ws, sendMessage); err != nil { s.Logger.Error(err, "failed to send message", "message", sendMessage, "address", ws.RemoteAddr()) } - } else if receivedMessage.Emulator == "" { - sendMessage.Accept = BadEmulator - sendMessage.Message = "Emulator name cannot be empty" - if err := s.sendData(ws, sendMessage); err != nil { - s.Logger.Error(err, "failed to send message", "message", sendMessage, "address", ws.RemoteAddr()) - } - } else if authErr := s.validateAuth(receivedMessage); authErr != nil { - sendMessage.Accept = BadAuth + } else if acceptValue, authErr := s.validateAuth(receivedMessage); authErr != nil { + sendMessage.Accept = acceptValue sendMessage.Message = authErr.Error() - s.Logger.Info("bad auth code", "authError", authErr.Error(), "message", receivedMessage, "address", ws.RemoteAddr()) + s.Logger.Info("bad auth", "authError", authErr.Error(), "message", receivedMessage, "address", ws.RemoteAddr()) if err := s.sendData(ws, sendMessage); err != nil { s.Logger.Error(err, "failed to send message", "message", sendMessage, "address", ws.RemoteAddr()) } @@ -466,22 +460,10 @@ func (s *LobbyServer) wsHandler(w http.ResponseWriter, r *http.Request) { } } else if receivedMessage.Type == TypeRequestGetRooms { sendMessage.Type = TypeReplyGetRooms - if receivedMessage.NetplayVersion != NetplayAPIVersion { - sendMessage.Accept = MismatchVersion - sendMessage.Message = "Client and server not at same API version. Please update your emulator" - if err := s.sendData(ws, sendMessage); err != nil { - s.Logger.Error(err, "failed to send message", "message", sendMessage, "address", ws.RemoteAddr()) - } - } else if receivedMessage.Emulator == "" { - sendMessage.Accept = BadEmulator - sendMessage.Message = "Emulator name cannot be empty" - if err := s.sendData(ws, sendMessage); err != nil { - s.Logger.Error(err, "failed to send message", "message", sendMessage, "address", ws.RemoteAddr()) - } - } else if authErr := s.validateAuth(receivedMessage); authErr != nil { - sendMessage.Accept = BadAuth + if acceptValue, authErr := s.validateAuth(receivedMessage); authErr != nil { + sendMessage.Accept = acceptValue sendMessage.Message = authErr.Error() - s.Logger.Info("bad auth code", "authError", authErr.Error(), "message", receivedMessage, "address", ws.RemoteAddr()) + s.Logger.Info("bad auth", "authError", authErr.Error(), "message", receivedMessage, "address", ws.RemoteAddr()) if err := s.sendData(ws, sendMessage); err != nil { s.Logger.Error(err, "failed to send message", "message", sendMessage, "address", ws.RemoteAddr()) } @@ -517,8 +499,12 @@ func (s *LobbyServer) wsHandler(w http.ResponseWriter, r *http.Request) { } } else if receivedMessage.Type == TypeRequestJoinRoom { if !authenticated { - s.Logger.Error(fmt.Errorf("bad auth"), "User tried to join room without being authenticated", "address", ws.RemoteAddr()) - continue + if _, authErr := s.validateAuth(receivedMessage); authErr != nil { + s.Logger.Error(fmt.Errorf("bad auth"), "User tried to join room without being authenticated", "address", ws.RemoteAddr()) + continue + } else { + authenticated = true + } } var duplicateName bool var accepted int