From a88b9d2078ecbe4c33dac3f10c629b24b3be99be Mon Sep 17 00:00:00 2001 From: swz-git Date: Wed, 23 Oct 2024 01:33:23 +0200 Subject: [PATCH] bad tokio impl --- Cargo.toml | 4 +- examples/atba_agent/main.rs | 46 +++++++++-- src/agents.rs | 159 ++++++++++++++++++++++++++++++++---- src/lib.rs | 62 +++++++++++++- 4 files changed, 248 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 98a5bcf..5157401 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,11 +18,13 @@ codegen-units = 1 panic = "abort" [dependencies] -kanal = { version = "0.1.0-pre8", default-features = false } +kanal = { version = "0.1.0-pre8" } glam = { version = "0.27.0", optional = true } serde = { version = "1.0.210", features = ["derive"] } thiserror = "1.0.50" planus = "1.0.0" +tokio = { version = "1.41.0", features = ["full"] } +futures = "0.3.31" [features] default = ["glam"] diff --git a/examples/atba_agent/main.rs b/examples/atba_agent/main.rs index bc741f1..c6fb61f 100644 --- a/examples/atba_agent/main.rs +++ b/examples/atba_agent/main.rs @@ -1,9 +1,9 @@ use std::{env, f32::consts::PI}; use rlbot_interface::{ - agents::{run_agents, Agent}, + agents::{run_agents, run_agents_tokio, Agent}, rlbot::{ConnectionSettings, ControllableInfo, ControllerState, PlayerInput}, - Packet, RLBotConnection, + Packet, RLBotConnection, RLBotConnectionTokio, }; struct AtbaAgent { @@ -68,12 +68,16 @@ impl Agent for AtbaAgent { packets_to_send } } -fn main() { + +#[tokio::main] +async fn main() { println!("Connecting"); let rlbot_addr = env::var("RLBOT_CORE_ADDR").unwrap_or("127.0.0.1:23234".to_owned()); - let rlbot_connection = RLBotConnection::new(&rlbot_addr).expect("connection"); + let rlbot_connection = RLBotConnectionTokio::new(&rlbot_addr) + .await + .expect("connection"); println!("Running!"); @@ -84,8 +88,7 @@ fn main() { let agent_id = env::var("RLBOT_AGENT_ID").unwrap_or("rlbot/rust-example-bot".into()); - // Blocking - run_agents::( + run_agents_tokio::( ConnectionSettings { agent_id: agent_id.clone(), wants_ball_predictions: true, @@ -94,7 +97,38 @@ fn main() { }, rlbot_connection, ) + .await .expect("to run agent"); println!("Agent(s) with agent_id `{agent_id}` exited nicely") } + +// fn main() { +// println!("Connecting"); + +// let rlbot_addr = env::var("RLBOT_CORE_ADDR").unwrap_or("127.0.0.1:23234".to_owned()); + +// let rlbot_connection = RLBotConnection::new(&rlbot_addr).expect("connection"); + +// println!("Running!"); + +// // The hivemind field in your bot.toml file decides if rlbot core is going to +// // start your bot as one or multiple instances of your binary/exe. +// // If the hivemind field is set to true, one instance of your bot will handle +// // all of the bots in a team. + +// let agent_id = env::var("RLBOT_AGENT_ID").unwrap_or("rlbot/rust-example-bot".into()); + +// run_agents::( +// ConnectionSettings { +// agent_id: agent_id.clone(), +// wants_ball_predictions: true, +// wants_comms: true, +// close_after_match: true, +// }, +// rlbot_connection, +// ) +// .expect("to run agent"); + +// println!("Agent(s) with agent_id `{agent_id}` exited nicely") +// } diff --git a/src/agents.rs b/src/agents.rs index 1e56e03..a675b68 100644 --- a/src/agents.rs +++ b/src/agents.rs @@ -1,6 +1,9 @@ -use std::{collections::VecDeque, io::Write, mem, thread}; +use std::{collections::VecDeque, io::Write, iter, mem, sync::Arc, thread, time::Instant}; -use crate::{rlbot::*, Packet, RLBotConnection, RLBotError}; +use futures::FutureExt; +use tokio::{io::AsyncWriteExt, task::block_in_place}; + +use crate::{rlbot::*, Packet, RLBotConnection, RLBotConnectionTokio, RLBotError}; #[allow(unused_variables)] pub trait Agent { @@ -128,7 +131,15 @@ pub fn run_agents( continue; // no need to send nothing } - write_multiple_packets(&mut connection, mem::take(&mut to_send))?; + let raw_packets = build_multiple_packets(&mut connection.builder, mem::take(&mut to_send)); + connection.stream.write_all(&raw_packets).map_err(|e| { + let new_e: RLBotError = e.into(); + new_e + })?; + connection.stream.flush().map_err(|e| { + let new_e: RLBotError = e.into(); + new_e + })?; } for (_, thread_handle) in threads.into_iter() { @@ -138,25 +149,145 @@ pub fn run_agents( Ok(()) } -fn write_multiple_packets( - connection: &mut RLBotConnection, - packets: Vec, -) -> Result<(), RLBotError> { - let to_write = packets +/// Run multiple agents using tokio (async). They share a connection. +/// Ok(()) means a successful exit; one of the bots received a None packet. +pub async fn run_agents_tokio( + connection_settings: ConnectionSettings, + mut connection: RLBotConnectionTokio, +) -> Result<(), AgentError> { + connection.send_packet(connection_settings).await?; + + let mut packets_to_process = VecDeque::new(); + + // Wait for Controllable(Team)Info to know which indices we control + let controllable_team_info = loop { + let packet = connection.recv_packet().await?; + if let Packet::ControllableTeamInfo(x) = packet { + break x; + } else { + packets_to_process.push_back(packet); + continue; + } + }; + + let mut threads = vec![]; + + let (thread_send, main_recv) = kanal::bounded_async(controllable_team_info.controllables.len()); + for (i, controllable_info) in controllable_team_info.controllables.iter().enumerate() { + let (main_send, thread_recv) = kanal::bounded_async::(0); + let thread_send = thread_send.clone(); + let controllable_info = controllable_info.clone(); + + threads.push(( + main_send, + tokio::spawn(async move { + let mut bot = T::new(controllable_info); + + while let Ok(packet) = thread_recv.recv().await { + match packet { + Packet::None => break, + Packet::GamePacket(x) => { + thread_send.send(bot.tick(x)).await.expect("thread send"); + } + Packet::FieldInfo(x) => { + thread_send + .send(bot.on_field_info(x)) + .await + .expect("thread send"); + } + Packet::MatchSettings(x) => { + thread_send + .send(bot.on_match_settings(x)) + .await + .expect("thread send"); + } + Packet::MatchComm(x) => { + thread_send + .send(bot.on_match_comm(x)) + .await + .expect("thread send"); + } + Packet::BallPrediction(x) => { + thread_send + .send(bot.on_ball_prediction(x)) + .await + .expect("thread send"); + } + _ => { /* The rest of the packets are only client -> server */ } + } + } + drop(thread_send); + drop(thread_recv); + }), + )); + } + // drop never-again-used copy of thread_send + // NO NOT REMOVE, otherwise main_recv.recv() will never error + // which we rely on for clean exiting + drop(thread_send); + + // We only need to send one init complete with the first + // spawn id even though we may be running multiple bots. + if controllable_team_info.controllables.is_empty() { + // run no bots? no problem, done + return Ok(()); + }; + + connection.send_packet(Packet::InitComplete).await?; + + 'main_loop: loop { + let inst = Instant::now(); + let packet = packets_to_process + .pop_front() + .unwrap_or(connection.recv_packet().await?); + // let packet = Arc::new(packet); + + // dbg!(Instant::now().duration_since(inst)); + let inst = Instant::now(); + + for (sender, _thread) in &threads { + if let Err(x) = sender.send(packet.clone()).await { + return Err(AgentError::AgentPanic); + } + } + + // dbg!(Instant::now().duration_since(inst)); + let inst = Instant::now(); + + for _i in 0..threads.len() { + let r = main_recv.recv().await.unwrap(); + if r.len() == 0 { + continue; + } + connection + .stream + .write_all(&build_multiple_packets(&mut connection.builder, r)) + .await + .map_err(|e| { + let new_e: RLBotError = e.into(); + new_e + })?; + } + connection.stream.flush().await.unwrap(); + + // dbg!(Instant::now().duration_since(inst)); + } + + // Ok(()) +} + +fn build_multiple_packets(builder: &mut planus::Builder, packets: Vec) -> Vec { + packets .into_iter() // convert Packet to Vec that rlbot can understand .map(|x| { let data_type_bin = x.data_type().to_be_bytes().to_vec(); - let payload = x.build(&mut connection.builder); + let payload = x.build(builder); let data_len_bin = (payload.len() as u16).to_be_bytes().to_vec(); [data_type_bin, data_len_bin, payload].concat() }) .collect::>() // Join all raw packets together - .concat(); - - connection.stream.write_all(&to_write)?; - - Ok(()) + .concat() } diff --git a/src/lib.rs b/src/lib.rs index 3e26d2f..888e585 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,10 @@ pub mod agents; #[cfg(feature = "glam")] pub use glam; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + task::block_in_place, +}; pub(crate) mod flat_wrapper; @@ -209,11 +213,65 @@ impl RLBotConnection { Ok(packet) } - pub fn new(addr: &str) -> Result { + pub fn new(addr: &str) -> Result { let stream = TcpStream::connect(addr)?; stream.set_nodelay(true)?; - Ok(RLBotConnection { + Ok(Self { + stream, + builder: planus::Builder::with_capacity(1024), + recv_buf: [0u8; u16::MAX as usize], + }) + } +} + +pub struct RLBotConnectionTokio { + stream: tokio::net::TcpStream, + builder: planus::Builder, + recv_buf: [u8; u16::MAX as usize], +} + +impl RLBotConnectionTokio { + async fn send_packet_enum(&mut self, packet: Packet) -> Result<(), RLBotError> { + let data_type_bin = packet.data_type().to_be_bytes().to_vec(); + let payload = block_in_place(|| packet.build(&mut self.builder)); + let data_len_bin = (payload.len() as u16).to_be_bytes().to_vec(); + + // Join so we make sure everything gets written in the right order + let joined = [data_type_bin, data_len_bin, payload].concat(); + + self.stream.write_all(&joined).await?; + self.stream.flush().await?; + Ok(()) + } + + pub async fn send_packet>(&mut self, packet: P) -> Result<(), RLBotError> { + self.send_packet_enum(packet.into()).await + } + + pub async fn recv_packet(&mut self) -> Result { + let mut buf = [0u8; 4]; + + // TODO: disable work stealing here if this causes problems + self.stream.read_exact(&mut buf).await?; + + let data_type = u16::from_be_bytes([buf[0], buf[1]]); + let data_len = u16::from_be_bytes([buf[2], buf[3]]); + + let buf = &mut self.recv_buf[0..data_len as usize]; + + self.stream.read_exact(buf).await?; + + let packet = Packet::from_payload(data_type, buf)?; + + Ok(packet) + } + + pub async fn new(addr: &str) -> Result { + let stream = tokio::net::TcpStream::connect(addr).await?; + stream.set_nodelay(true)?; + + Ok(Self { stream, builder: planus::Builder::with_capacity(1024), recv_buf: [0u8; u16::MAX as usize],