diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/doublebuffer.rs | 31 | ||||
-rw-r--r-- | src/game.rs | 381 | ||||
-rw-r--r-- | src/gui.rs | 123 | ||||
-rw-r--r-- | src/main.rs | 38 | ||||
-rw-r--r-- | src/nn.rs | 105 | ||||
-rw-r--r-- | src/trainer.rs | 66 |
6 files changed, 744 insertions, 0 deletions
diff --git a/src/doublebuffer.rs b/src/doublebuffer.rs new file mode 100644 index 0000000..563f157 --- /dev/null +++ b/src/doublebuffer.rs @@ -0,0 +1,31 @@ +pub struct DoubleBuffer<T> { + a1: Vec<T>, + a2: Vec<T>, + switch: bool, +} + +impl<T> DoubleBuffer<T> { + pub fn new(a1: Vec<T>, a2: Vec<T>) -> Self { + Self { a1, a2, switch: false } + } + + pub fn switch(&mut self) { + self.switch = !self.switch; + } + + pub fn first(&self) -> &Vec<T> { + if self.switch { &self.a2 } else { &self.a1 } + } + + pub fn first_mut(&mut self) -> &mut Vec<T> { + if self.switch { &mut self.a2 } else { &mut self.a1 } + } + + pub fn second(&self) -> &Vec<T> { + if self.switch { &self.a1 } else { &self.a2 } + } + + pub fn second_mut(&mut self) -> &mut Vec<T> { + if self.switch { &mut self.a1 } else { &mut self.a2 } + } +} diff --git a/src/game.rs b/src/game.rs new file mode 100644 index 0000000..f6d0c51 --- /dev/null +++ b/src/game.rs @@ -0,0 +1,381 @@ +use rand::{RngCore, SeedableRng, rngs::StdRng}; + +const PLAYER_X: u32 = 3; +const FLOOR_HEIGHT: u32 = 1; +const G: u32 = 2; +const PLAYER_BOOST: u32 = 3; +const PLAYER_BOOST_TIME: u32 = 6; +const PLAYER_BOOST_MIN_TIME: u32 = 4; + +#[derive(Clone, Debug, Default)] +pub struct Status { + pub player: [u32; 2], + pub player_v: i32, + pub fields: Vec<Field>, +} + +pub trait Game: Send { + fn size(&self) -> (u32, u32); + fn update(&mut self) -> Option<Status>; + fn jump(&mut self); + fn status(&self) -> Status; + fn get_points(&self) -> u64; +} + +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum Field { + Air, Wall, Spike +} + +impl Field { + pub fn repr(&self) -> char { + match self { + Field::Air => '.', + Field::Wall => 'Z', + Field::Spike => 'x', + } + } +} + +#[derive(Clone, Debug)] +struct PlayerBoost { + freq: u32, + lifetime: u32, +} + +impl PlayerBoost { + fn new() -> Self { + Self { + freq: 0, + lifetime: 0 + } + } + + fn activate(&mut self) { + *self = Self { + freq: 1, + lifetime: PLAYER_BOOST_TIME + } + } + + fn is(&self) -> bool { + self.lifetime != 0 + } + + fn get_boost(&self) -> u32 { + if self.lifetime > 0 { + let m = self.lifetime % self.freq; + if m == 0 { + PLAYER_BOOST + G + } else if PLAYER_BOOST > m { + PLAYER_BOOST + G - m + } else { G as u32 } + } else { 0 } + } + + fn update(&mut self) { + if self.lifetime == 0 { return; } + if self.lifetime <= PLAYER_BOOST_MIN_TIME { + self.freq += 1; + } + self.lifetime -= 1; + } +} + +#[derive(Clone, Debug)] +pub struct JumpGame { + grid: Vec<Field>, + cols: u32, rows: u32, + player_boost: PlayerBoost, + jumping: bool, + gap_len: u32, + platforms: [(u32, u32); 2], + spike: u32, spike_down: bool, + player: u32, + status: Status, + rng: StdRng, + points: u64, +} + +impl JumpGame { + pub fn new(cols: u32, rows: u32, seed: u64) -> Self { + let grid: Vec<Field> = std::iter::repeat(Field::Air) + .enumerate() + .map(|(i, v)| if (i as u32 % rows) <= FLOOR_HEIGHT + { Field::Wall } else { v }) + .take((cols * rows) as usize) + .collect(); + Self { + grid: grid.clone(), + cols, rows, + player_boost: PlayerBoost::new(), + jumping: false, + gap_len: 0, + platforms: [(0, 0); 2], + spike: 0, spike_down: false, + player: 3, + status: Status { + player: [ PLAYER_X, 3 ], + player_v: 0, + fields: grid + }, + rng: StdRng::seed_from_u64(seed), + points: 0, + } + } + + pub fn get_field(&self, col: u32, row: u32) -> Option<&Field> { + self.grid.get((row + col * self.rows) as usize) + } + + pub fn mut_row(&mut self, col: u32) -> Option<&mut [Field]> { + if col >= self.cols { None } + else { + let pos1 = (col * self.rows) as usize; + let pos2 = pos1 + self.rows as usize; + Some(&mut self.grid[pos1..pos2]) + } + } + + pub fn shift(&mut self) { + self.grid.rotate_left(self.rows as usize) + } + + pub fn normal_col(&self, n: u32) -> Field { + if n <= FLOOR_HEIGHT || self.collides(n) { + Field::Wall + } else { Field::Air } + } + + pub fn spike_col(&self, n: u32) -> Field { + if n <= FLOOR_HEIGHT { + Field::Wall + } else if self.spike <= n + 1 && self.spike + 0 >= n { + Field::Spike + } else { Field::Air } + } + + pub fn gap_col(&self, n: u32) -> Field { + if self.collides(n) { + Field::Wall + } else { Field::Air } + } + + pub fn gen_col_with<F: Fn(&Self, u32) -> Field>(&mut self, f: &F) { + let vec: Vec<Field> = std::iter::repeat(Field::Air) + .enumerate() + .map(|(i, _)| f(self, i as u32)) + .take(self.rows as usize) + .collect(); + self.mut_row(self.cols - 1).unwrap() + .clone_from_slice(&vec); + } + + pub fn gen_col(&mut self) { + if self.gap_len > 0 { + self.gen_col_with(&Self::gap_col) + } else if self.chance((1, 15)) { + self.gen_col_with(&Self::spike_col) + } else { + self.gen_col_with(&Self::normal_col) + } + } + + fn update_platforms(&mut self) { + for i in 0..self.platforms.len() { + let (_, l) = self.platforms[i]; + if l <= 0 { + if self.chance((1, 29)) { + self.platforms[i] = ( + self.randint(FLOOR_HEIGHT + 2, self.rows - 1), + self.randint(4, 8)); + } else { + self.platforms[i].1 = 0; + } + } else { + self.platforms[i].1 -= 1; + } + } + } + + fn collides(&self, n: u32) -> bool { + n != 0 && self.platforms.iter().any(|&(h, _)| h == n) + } + + pub fn update_gap(&mut self) { + if self.gap_len > 0 { + self.gap_len -= 1; + } else if self.chance((1, 12)) { + self.gap_len = self.randint(2, 5); + } + } + + pub fn update_spike(&mut self) { + if self.spike_down { + if self.spike <= 0 { + self.spike_down = false; + self.spike += 1 + } else { self.spike -= 1; } + } else { + if self.spike >= self.rows - 1 { + self.spike_down = true; + self.spike -= 1; + } else { self.spike += 1; } + } + } + + pub fn check_dead(&self) -> Option<()> { + self.get_field(PLAYER_X, self.player) + .filter(|&&f| f != Field::Spike) + .map(|_|()) + } + + pub fn apply_phys(&mut self) -> Option<()> { + let v = self.player_boost.get_boost() as i32; + let v = v - G as i32; + self.status.player_v = v; + if v != 0 { + let (oldpos, newpos) = (self.player as i32, self.player as i32 + v); + if newpos < 0 { return None; } + let mut endpos = None; + let up = oldpos < newpos; + let (mut it1, mut it2) = (oldpos..(newpos+1), (newpos..oldpos).rev()); + let steps: &mut dyn Iterator<Item=i32> = + if up { &mut it1 } + else { &mut it2 }; + for h in steps.skip(0) { + match self.get_field(PLAYER_X, h as u32) { + Some(Field::Air) => endpos = Some(h), + Some(Field::Wall) | None => { + if !up { self.jumping = false; } + break + }, + Some(Field::Spike) => return None, + } + } + if let Some(pos) = endpos { + let mut pos = pos as u32; + if pos >= self.rows { pos = self.rows - 1; } + self.player = pos as u32; + } + } + Some(()) + } + + pub fn update_status(&mut self) { + self.status.player[1] = self.player; + self.status.fields = self.grid.clone(); + } + + fn randint(&mut self, a: u32, b: u32) -> u32 { + (self.rng.next_u32() % (b - a + 1)) + a + } + + fn chance(&mut self, (c, n): (u32, u32)) -> bool { + self.randint(0, n - 1) < c + } +} + +impl Game for JumpGame { + fn size(&self) -> (u32, u32) { + (self.cols, self.rows) + } + + fn update(&mut self) -> Option<Status> { + self.player_boost.update(); + self.update_gap(); + self.update_spike(); + self.update_platforms(); + self.apply_phys()?; + self.shift(); + self.check_dead()?; + self.gen_col(); + self.update_status(); + self.points += 1; + Some(self.status()) + } + + fn jump(&mut self) { + if !self.jumping { + self.jumping = true; + self.player_boost.activate(); + } + } + + fn status(&self) -> Status { + self.status.clone() + } + + fn get_points(&self) -> u64 { + self.points + } +} + +pub trait GameLogger { + type Output; + fn new_empty(cols: u32, rows: u32) -> Self; + fn append(&mut self, status: &Status); + fn extract(self) -> Self::Output; +} + +pub struct PointLogger { + points: u64 +} + +impl GameLogger for PointLogger { + type Output = u64; + + fn new_empty(_cols: u32, _rows: u32) -> Self { + Self { points: 0 } + } + fn append(&mut self, _status: &Status) { + self.points += 1; + } + fn extract(self) -> Self::Output { + self.points + } +} + +#[derive(Clone, Debug)] +pub struct GameLog { + status: Vec<Status>, + frame: u32, + size: (u32, u32) +} + +impl GameLogger for GameLog { + type Output = Self; + fn new_empty(cols: u32, rows: u32) -> Self { + Self { + status: vec![], + frame: 0, + size: (cols, rows) + } + } + fn extract(self) -> Self::Output { + self + } + fn append(&mut self, status: &Status) { + self.status.push(status.clone()) + } +} + +impl Game for GameLog { + fn size(&self) -> (u32, u32) { self.size } + fn update(&mut self) -> Option<Status> { + if self.frame >= self.status.len() as u32 { + None + } else { + let status = self.status(); + self.frame += 1; + Some(status) + } + } + fn jump(&mut self) {} + fn status(&self) -> Status { + self.status[self.frame as usize].clone() + } + fn get_points(&self) -> u64 { + self.status.len() as u64 + } +} diff --git a/src/gui.rs b/src/gui.rs new file mode 100644 index 0000000..b13891c --- /dev/null +++ b/src/gui.rs @@ -0,0 +1,123 @@ +use ggez::event::EventHandler; +use crate::game::{Game, Status, Field}; +use std::sync::{Mutex, Arc, atomic::AtomicU32}; + +pub struct Gui { + ctx: ggez::Context, + event_loop: ggez::event::EventsLoop, +} + +impl Gui { + pub fn new() -> Self { + let (mut ctx, event_loop) = + ggez::ContextBuilder::new("neo x", "natrixaeria") + .window_setup(ggez::conf::WindowSetup::default() + .vsync(true)) + .build().unwrap(); + Self { + ctx, event_loop, + } + } + + pub fn run<G: Game + 'static>(mut game: G) -> std::thread::JoinHandle<()> { + std::thread::spawn(|| Self::new().run_in_thread(game)) + } + + fn run_in_thread<G: Game + 'static>(&mut self, mut game: G) { + let status = Arc::new(Mutex::new(Some(game.status()))); + let jump = Arc::new(AtomicU32::new(0)); + let mut handler = GameEventHandler { status: Arc::clone(&status), size: game.size(), jump: Arc::clone(&jump) }; + std::thread::spawn(move || { + let mut speed = 85.0; + while let Some(new_status) = game.update() { + if jump.fetch_and(0, std::sync::atomic::Ordering::SeqCst) > 0 { + game.jump(); + } + if let Ok(mut lock) = status.lock() { + *lock = Some(new_status); + } else { break; } + std::thread::sleep(std::time::Duration::from_millis(speed as u64)); + speed *= 0.9992; + println!("{:04} | {}", game.get_points(), speed); + } + println!("Points: {}", game.get_points()); + { *status.lock().unwrap() = None; } + }); + match ggez::event::run(&mut self.ctx, &mut self.event_loop, &mut handler) { + Err(_) => (), + _ => (), + } + } +} + +struct GameEventHandler { + status: Arc<Mutex<Option<Status>>>, + size: (u32, u32), + jump: Arc<AtomicU32>, +} + +impl GameEventHandler { + fn draw_rect(&mut self, ctx: &mut ggez::Context, translation: [f32; 4], color: ggez::graphics::Color, draw_mode: &ggez::graphics::DrawParam) -> ggez::GameResult<()> { + let rect = ggez::graphics::Mesh::new_rectangle( + ctx, + ggez::graphics::DrawMode::fill(), + ggez::graphics::Rect::new(translation[0], translation[1], translation[2], translation[3]), + color + )?; + ggez::graphics::draw(ctx, &rect, *draw_mode) + } + + fn draw_field(&mut self, ctx: &mut ggez::Context, col: u32, row: u32, color: ggez::graphics::Color, draw_mode: &ggez::graphics::DrawParam) -> ggez::GameResult<()> { + let (w, h) = ggez::graphics::size(ctx); + let (rx, ry) = self.size; + let (u, v) = (w / (rx as f32), h / (ry as f32)); + let (u, v) = if u < v { (u, u) } else { (v, v) }; + let (x, y) = (col as f32 * u, row as f32 * v); + let translation = [ x, y, u, v ]; + self.draw_rect(ctx, translation, color, draw_mode) + } +} + +impl EventHandler for GameEventHandler { + fn update(&mut self, ctx: &mut ggez::Context) -> ggez::GameResult<()> { + Ok(()) + } + + fn draw(&mut self, ctx: &mut ggez::Context) -> ggez::GameResult<()> { + let status = { (*self.status.lock().unwrap()).clone() }; + if status.is_none() { + println!("exxiittt"); + ggez::event::quit(ctx); + return Err(ggez::error::GameError::ResourceLoadError(format!("ending app"))); + } + let status = status.unwrap(); + let bg = ggez::graphics::BLACK; + let draw_mode = ggez::graphics::DrawParam::new(); + ggez::graphics::clear(ctx, bg); + + let (cols, rows) = self.size; + for row in 0..rows { + for col in 0..cols { + let color = match status.fields[((rows - row - 1) + col * rows) as usize] { + Field::Air => continue, + Field::Wall => ggez::graphics::Color::from_rgb(100, 100, 100), + Field::Spike => ggez::graphics::Color::from_rgb(200, 50, 6), + }; + self.draw_field(ctx, col, row, color, &draw_mode)?; + } + } + + let player_color = ggez::graphics::Color::from_rgb(0, 180, 250); + let (x, y) = (status.player[0], status.player[1]); + if rows >= y + 1 { + self.draw_field(ctx, x, rows - y - 1, player_color, &draw_mode)?; + } + ggez::graphics::present(ctx) + } + + fn key_down_event(&mut self, _ctx: &mut ggez::Context, keycode: ggez::event::KeyCode, _keymods: ggez::event::KeyMods, _repeat: bool) { + if keycode == ggez::event::KeyCode::Space { + self.jump.store(1, std::sync::atomic::Ordering::SeqCst); + } + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..a673753 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,38 @@ +mod doublebuffer; +mod nn; +mod trainer; +mod game; +mod gui; + +fn main() { + /*let (cols, rows) = (36, 12); + let mut game = JumpGame::new(cols, rows); + println!("\x1b[1;1H\x1b[2J"); + + while let Some(status) = game.update() { + /*print!("\x1b[1;1H"); + for row in 0..rows { + for col in 0..cols { + print!("{}", status.fields[((rows - row - 1) + col * rows) as usize].repr()); + } + println!(""); + } + std::thread::sleep_ms(80);*/ + }*/ + + /*let (cols, rows) = (18, 16); + let game = game::JumpGame::new(cols, rows, 3); + gui::Gui::new() + .run(game); + return;*/ + /*let mut n = nn::Nn::new_by_game_res(cols, rows); + let points = n.execute(game); + println!("{}", points);*/ + /*let mut n = nn::Nn::new_by_game_res(cols, rows, &nn::Nn::new_random); + let game = n.execute_logged(game); + gui::Gui::new() + .run(game);*/ + + trainer::train(); + //gui::Gui::run(game::JumpGame::new(18, 16, 4)).join().unwrap() +} diff --git a/src/nn.rs b/src/nn.rs new file mode 100644 index 0000000..be600ea --- /dev/null +++ b/src/nn.rs @@ -0,0 +1,105 @@ +use rand::Rng; +use crate::game::{Game, GameLog, GameLogger, PointLogger, Status, Field}; +use crate::doublebuffer::DoubleBuffer; + +#[derive(Debug, Clone)] +pub struct Layer { + /// vector of (vectors of weights onto the past layer) + nodes: Vec<Vec<f32>>, +} + +#[derive(Debug, Clone)] +pub struct Nn { + layers: Vec<Layer> +} + +fn randaround(x: f32) -> f32 { + rand::thread_rng().gen_range(-x, x) +} + +impl Nn { + pub fn new_empty(layers: &[u32]) -> Self { + Self { + layers: layers.iter().take(layers.len() - 1).zip(layers.iter().skip(1)).map(|(&n1, &n2)| Layer { + nodes: std::iter::repeat(std::iter::repeat(1.0).take(n1 as usize).collect()).take(n2 as usize).collect() + }).collect(), + } + } + + pub fn new_random(layers: &[u32]) -> Self { + Self { + layers: layers.iter().take(layers.len() - 1).zip(layers.iter().skip(1)).map(|(&n1, &n2)| Layer { + nodes: std::iter::repeat_with(|| (0..(n1 as usize)).map(|_| randaround(1.0)).collect()).take(n2 as usize).collect() + }).collect(), + } + } + + pub fn add_noise(&mut self, hard_noise: f32, fine_noise: f32, w: f32, r: f32) { + let complete = randaround(1.0).abs() < r; + for Layer { nodes: layer_mappings } in self.layers.iter_mut() { + for layer_mapping in layer_mappings.iter_mut() { + layer_mapping.iter_mut().for_each(|v| { + if complete { + *v = randaround(hard_noise); + } else { + *v += if randaround(1.0).abs() < w + { randaround(hard_noise) } else { randaround(fine_noise) } + } + }); + } + } + } + + pub fn new_by_game_res<F: Fn(&[u32]) -> Self>(cols: u32, rows: u32, f: &F) -> Self { + // fields + 1xplayer_y + 1xplayer_v + let input_layer_size = cols * rows + 1 + 1; + f(&[input_layer_size, + rows * cols, + cols, + 1]) + } + + fn status_to_layer(status: &Status, _cols: u32, rows: u32) -> Vec<f32> { + [status.player[1] as f32 / rows as f32, status.player_v as f32].iter() + .map(|&x| x) + .chain( + status.fields.iter().map(|f| match f { + Field::Air => 0.0, + Field::Wall => 1.0, + Field::Spike => -1.0, + }) + ).collect() + } + + pub fn execute<G: Game>(&self, game: G) -> u64 { + self.execute_with_logger::<G, PointLogger>(game) + } + + pub fn execute_logged<G: Game>(&self, game: G) -> GameLog { + self.execute_with_logger::<G, GameLog>(game) + } + + fn execute_with_logger<G: Game, L: GameLogger>(&self, mut game: G) -> L::Output { + let (cols, rows) = game.size(); + let mut log = L::new_empty(cols, rows); + log.append(&game.status()); + while let Some(status) = game.update() { + log.append(&status); + let input_layer = Self::status_to_layer(&status, cols, rows); + let mut db = DoubleBuffer::new(input_layer, vec![]); + for Layer {nodes: layer_mappings} in self.layers.iter().skip(1) { + *db.second_mut() = layer_mappings.iter().map(|layer_mapping| { + layer_mapping.iter().zip(db.first().iter()) + .map(|(w, v)| v * w) + .sum() + }).collect(); + db.switch(); + } + //println!("out: {:?} | threshold: {}", db.first(), cols); + if db.first()[0] >= cols as f32 { + game.jump(); + } + } + log.extract() + } +} diff --git a/src/trainer.rs b/src/trainer.rs new file mode 100644 index 0000000..916d515 --- /dev/null +++ b/src/trainer.rs @@ -0,0 +1,66 @@ +use crate::nn::Nn; +use crate::game::{Game, JumpGame}; +use crate::gui::Gui; +use rand::Rng; + +fn randto(x: f32) -> f32 { + rand::thread_rng().gen_range(0.0, x) +} + +pub struct Trainer { +} + +pub fn train() { + let (cols, rows) = (20, 16); + //let colony_size = 64; + let colony_size = 64; + let thread_count = colony_size; + + let mut seed = 42; + let mut best: Option<Nn> = None; + for generation in 0.. { + println!("Generation #{}", generation); + println!("==============="); + println!("generating colony"); + let nn_colony: Vec<Nn> = if generation == 0 { + let gen = || Nn::new_by_game_res(cols, rows, &Nn::new_random); + std::iter::repeat_with(gen) + .take(colony_size).collect() + } else if let Some(nn) = best.clone() { + let gen = || {let mut nn = nn.clone(); nn.add_noise(randto(2.0), randto(0.18), randto(1.0), 0.06); nn}; + std::iter::repeat_with(gen) + .take(colony_size - 1).chain(std::iter::once(nn.clone())).collect() + } else { + panic!("no colony") + }; + println!(" ...done"); + + let points: Vec<u64> = nn_colony.chunks(thread_count).flat_map(|chunk| { + println!("begin chunk"); + let threads: Vec<std::thread::JoinHandle<u64>> = chunk.iter().enumerate().map(|(i, nn)| { + println!("nn #{}", i); + let nn = nn.clone(); + std::thread::spawn(move || { + let game = JumpGame::new(cols, rows, seed); + nn.execute(game) + } + )}).collect(); + let mut results = Vec::with_capacity(threads.len()); + for thread in threads { + results.push(thread.join().unwrap()); + } + results + }).collect(); + best = points.iter().zip(nn_colony.iter()) + .max_by_key(|(ref p, _)| p.clone()) + .map(|(_, n)| n.clone()); + println!("points: {:?}", points); + println!("max: {:?}", points.iter().max()); + + if (generation % 25) == 0 { + let game = JumpGame::new(cols, rows, seed); + Gui::run(best.clone().unwrap().execute_logged(game)); + seed += 1; + } + } +} |