diff --git a/Cargo.toml b/Cargo.toml index 98a5bcf..5208d83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ glam = { version = "0.27.0", optional = true } serde = { version = "1.0.210", features = ["derive"] } thiserror = "1.0.50" planus = "1.0.0" +swap-buffer-queue = { version = "0.2.1", features = ["std"] } [features] default = ["glam"] diff --git a/examples/atba_agent/main.rs b/examples/atba_agent/main.rs index eda8fe7..bce1bc4 100644 --- a/examples/atba_agent/main.rs +++ b/examples/atba_agent/main.rs @@ -1,9 +1,9 @@ use std::{env, f32::consts::PI}; use rlbot_interface::{ - agents::{run_agents, Agent}, + agents::{run_agents, Agent, PacketQueue}, rlbot::{ConnectionSettings, ControllableInfo, ControllerState, PlayerInput}, - Packet, RLBotConnection, + RLBotConnection, }; struct AtbaAgent { @@ -14,17 +14,19 @@ impl Agent for AtbaAgent { fn new(controllable_info: ControllableInfo) -> Self { Self { controllable_info } } - fn tick(&mut self, game_packet: rlbot_interface::rlbot::GamePacket) -> Vec { - let mut packets_to_send = vec![]; - + fn tick( + &mut self, + game_packet: rlbot_interface::rlbot::GamePacket, + packet_queue: &mut PacketQueue, + ) { let Some(ball) = game_packet.balls.first() else { // If theres no ball, theres nothing to chase, don't do anything - return packets_to_send; + return; }; // We're not in the gtp, skip this tick if game_packet.players.len() <= self.controllable_info.index as usize { - return packets_to_send; + return; } let target = &ball.physics; @@ -58,14 +60,13 @@ impl Agent for AtbaAgent { controller.throttle = 1.; - packets_to_send.push( + packet_queue.push( PlayerInput { player_index: self.controllable_info.index, controller_state: controller, } .into(), ); - packets_to_send } } fn main() { diff --git a/src/agents.rs b/src/agents.rs index aa1e365..387254f 100644 --- a/src/agents.rs +++ b/src/agents.rs @@ -1,22 +1,32 @@ -use std::{collections::VecDeque, io::Write, mem, thread}; +use std::{ + collections::VecDeque, + io::Write, + mem, + sync::Arc, + thread::{self}, +}; + +use swap_buffer_queue::{buffer::VecBuffer, SynchronizedQueue}; use crate::{rlbot::*, Packet, RLBotConnection, RLBotError}; #[allow(unused_variables)] pub trait Agent { fn new(controllable_info: ControllableInfo) -> Self; - fn tick(&mut self, game_packet: GamePacket) -> Vec; - fn on_field_info(&mut self, field_info: FieldInfo) -> Vec { - vec![] - } - fn on_match_settings(&mut self, match_settings: MatchSettings) -> Vec { - vec![] - } - fn on_match_comm(&mut self, match_comm: MatchComm) -> Vec { - vec![] + fn tick(&mut self, game_packet: GamePacket, packet_queue: &mut PacketQueue) -> (); + fn on_field_info(&mut self, field_info: FieldInfo, packet_queue: &mut PacketQueue) -> () {} + fn on_match_settings( + &mut self, + match_settings: MatchSettings, + packet_queue: &mut PacketQueue, + ) -> () { } - fn on_ball_prediction(&mut self, ball_prediction: BallPrediction) -> Vec { - vec![] + fn on_match_comm(&mut self, match_comm: MatchComm, packet_queue: &mut PacketQueue) -> () {} + fn on_ball_prediction( + &mut self, + ball_prediction: BallPrediction, + packet_queue: &mut PacketQueue, + ) -> () { } } @@ -28,6 +38,25 @@ pub enum AgentError { PacketParseError(#[from] crate::RLBotError), } +/// A queue of packets to be sent to RLBotServer +pub struct PacketQueue { + internal_queue: Vec, +} + +impl PacketQueue { + pub fn new() -> Self { + PacketQueue { + internal_queue: Vec::with_capacity(16), + } + } + pub fn push(&mut self, packet: Packet) { + self.internal_queue.push(packet); + } + fn empty(&mut self) -> Vec { + mem::take(&mut self.internal_queue) + } +} + /// Run multiple agents on one thread each. They share a connection. /// Ok(()) means a successful exit; one of the bots received a None packet. pub fn run_agents( @@ -51,14 +80,21 @@ pub fn run_agents( let mut threads = vec![]; - let (thread_send, main_recv) = kanal::bounded(0); + let outgoing_queue: Arc>>> = + Arc::new(SynchronizedQueue::with_capacity( + // Allows 1024 packets per thread, should definitely be enough + controllable_team_info.controllables.len() * 1024, + )); for (i, controllable_info) in controllable_team_info.controllables.iter().enumerate() { - let (main_send, thread_recv) = kanal::bounded::(0); - let thread_send = thread_send.clone(); + let incoming_queue: Arc>> = + Arc::new(SynchronizedQueue::with_capacity(1024)); + // let thread_send = queue.clone(); let controllable_info = controllable_info.clone(); + let outgoing_queue = outgoing_queue.clone(); + threads.push(( - main_send, + incoming_queue.clone(), thread::Builder::new() .name(format!( "Agent thread {i} (spawn_id: {} index: {})", @@ -66,34 +102,56 @@ pub fn run_agents( )) .spawn(move || { let mut bot = T::new(controllable_info); + let mut incoming_queue_local = VecDeque::::with_capacity(8); + let mut outgoing_queue_local = PacketQueue::new(); + + loop { + let packet = match incoming_queue.try_dequeue() { + Ok(packets) => { + let mut iter = packets.into_iter(); + let first = iter.next().unwrap(); + incoming_queue_local.append(&mut iter.collect()); + first + } + Err(_) => { + let Some(packet) = incoming_queue_local.pop_front() else { + continue + }; + if incoming_queue_local.len() >= 8 { + // SKIP QUEUE + println!("WARN! Packet queue too long, skipping packets"); + incoming_queue_local.drain(..); + } + packet + } + }; - while let Ok(packet) = thread_recv.recv() { match packet { Packet::None => break, - Packet::GamePacket(x) => { - thread_send.send(bot.tick(x)).unwrap(); - } - Packet::FieldInfo(x) => thread_send.send(bot.on_field_info(x)).unwrap(), + Packet::GamePacket(x) => bot.tick(x, &mut outgoing_queue_local), + Packet::FieldInfo(x) => bot.on_field_info(x, &mut outgoing_queue_local), Packet::MatchSettings(x) => { - thread_send.send(bot.on_match_settings(x)).unwrap() + bot.on_match_settings(x, &mut outgoing_queue_local) } - Packet::MatchComm(x) => thread_send.send(bot.on_match_comm(x)).unwrap(), + Packet::MatchComm(x) => bot.on_match_comm(x, &mut outgoing_queue_local), Packet::BallPrediction(x) => { - thread_send.send(bot.on_ball_prediction(x)).unwrap() + bot.on_ball_prediction(x, &mut outgoing_queue_local) } - _ => { /* The rest of the packets are only client -> server */ } + _ => unreachable!() /* The rest of the packets are only client -> server */ } + + outgoing_queue.try_enqueue([outgoing_queue_local.empty()]).expect("Outgoing queue should be empty"); } - drop(thread_send); - drop(thread_recv); + // drop(thread_send); + // drop(thread_recv); }) .unwrap(), )); } - // drop never-again-used copy of thread_send - // NO NOT REMOVE, otherwise main_recv.recv() will never error - // which we rely on for clean exiting - drop(thread_send); + // // drop never-again-used copy of thread_send + // // NO NOT REMOVE, otherwise main_recv.recv() will never error + // // which we rely on for clean exiting + // drop(thread_send); // We only need to send one init complete with the first // spawn id even though we may be running multiple bots. @@ -104,46 +162,59 @@ pub fn run_agents( connection.send_packet(Packet::InitComplete)?; - // Main loop, broadcast packet to all of the bots, then wait for all of the responses - let mut to_send: Vec = Vec::with_capacity(controllable_team_info.controllables.len()); - 'main_loop: loop { - let packet = packets_to_process - .pop_front() - .unwrap_or(connection.recv_packet()?); - - for (sender, _) in threads.iter() { - let Ok(_) = sender.send(packet.clone()) else { - return Err(AgentError::AgentPanic); - }; + // Main loop, broadcast packet to all of the bots, then wait for all of the + // Rust limited to 32 for now, hopefully fixed in the future though not really a big deal + let mut to_send: [Vec; 32] = Default::default(); + let mut finished_thread_count = 0i64; + loop { + let mut maybe_packet = packets_to_process.pop_front(); + if maybe_packet.is_none() && connection.stream.peek(&mut 0u16.to_be_bytes()).is_ok() { + maybe_packet = Some(connection.recv_packet()?); + }; + + if let Some(packet) = maybe_packet { + for (thread_process_queue, _) in threads.iter() { + let Ok(_) = thread_process_queue.try_enqueue([packet.clone()]) else { + return Err(AgentError::AgentPanic); + }; + } } - for (_sender, _) in threads.iter() { - let Ok(list) = main_recv.recv() else { - break 'main_loop; - }; - to_send.extend(list.into_iter()) + let threads_len = threads.len() as i64; + + while finished_thread_count < threads_len { + if let Ok(messages) = outgoing_queue.try_dequeue() { + for msg in messages { + to_send[finished_thread_count as usize] = msg; + finished_thread_count += 1 + } + } + // if Instant::now().duration_since(start) + // // 1/120 of a second processing time - 250µs overhead + // > Duration::from_secs_f64(1. / 120. - 250. / 1_000_000.) + // { + // // println!("WARN! At least one thread was too slow to respond, skipping"); + // break; // Timeout, check next tick instead + // } } + finished_thread_count = 0; if to_send.is_empty() { continue; // no need to send nothing } - write_multiple_packets(&mut connection, mem::take(&mut to_send))?; - } - - for (_, thread_handle) in threads.into_iter() { - thread_handle.join().unwrap() + write_multiple_packets( + &mut connection, + mem::take(&mut to_send).into_iter().flatten(), + )?; } - - Ok(()) } fn write_multiple_packets( connection: &mut RLBotConnection, - packets: Vec, + packets: impl Iterator, ) -> Result<(), RLBotError> { let to_write = packets - .into_iter() // convert Packet to Vec that rlbot can understand .map(|x| { let data_type_bin = x.data_type().to_be_bytes().to_vec(); @@ -152,11 +223,11 @@ fn write_multiple_packets( [data_type_bin, data_len_bin, payload].concat() }) - .collect::>() - // Join all raw packets together - .concat(); + .flatten() + .collect::>(); connection.stream.write_all(&to_write)?; + connection.stream.flush()?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 3e26d2f..d94db90 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ use std::{ - io::{self, Read, Write}, - net::TcpStream, + io::{Read, Write}, + net::{AddrParseError, SocketAddr, TcpStream}, + str::FromStr, }; use planus::ReadAsRoot; @@ -30,9 +31,11 @@ pub enum PacketParseError { #[derive(Error, Debug)] pub enum RLBotError { #[error("Connection to RLBot failed")] - Connection(#[from] io::Error), + Connection(#[from] std::io::Error), #[error("Parsing packet failed")] PacketParseError(#[from] PacketParseError), + #[error("Invalid address, cannot parse")] + InvalidAddrError(#[from] AddrParseError), } #[allow(dead_code)] @@ -210,7 +213,8 @@ impl RLBotConnection { } pub fn new(addr: &str) -> Result { - let stream = TcpStream::connect(addr)?; + let stream = TcpStream::connect(SocketAddr::from_str(addr)?)?; + stream.set_nodelay(true)?; Ok(RLBotConnection {