From 223119654b4835c270ce4215f156e3c6236833b8 Mon Sep 17 00:00:00 2001 From: natrixaeria Date: Thu, 19 Dec 2019 20:16:20 +0100 Subject: Initial commit --- src/nn.rs | 105 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 src/nn.rs (limited to 'src/nn.rs') diff --git a/src/nn.rs b/src/nn.rs new file mode 100644 index 0000000..be600ea --- /dev/null +++ b/src/nn.rs @@ -0,0 +1,105 @@ +use rand::Rng; +use crate::game::{Game, GameLog, GameLogger, PointLogger, Status, Field}; +use crate::doublebuffer::DoubleBuffer; + +#[derive(Debug, Clone)] +pub struct Layer { + /// vector of (vectors of weights onto the past layer) + nodes: Vec>, +} + +#[derive(Debug, Clone)] +pub struct Nn { + layers: Vec +} + +fn randaround(x: f32) -> f32 { + rand::thread_rng().gen_range(-x, x) +} + +impl Nn { + pub fn new_empty(layers: &[u32]) -> Self { + Self { + layers: layers.iter().take(layers.len() - 1).zip(layers.iter().skip(1)).map(|(&n1, &n2)| Layer { + nodes: std::iter::repeat(std::iter::repeat(1.0).take(n1 as usize).collect()).take(n2 as usize).collect() + }).collect(), + } + } + + pub fn new_random(layers: &[u32]) -> Self { + Self { + layers: layers.iter().take(layers.len() - 1).zip(layers.iter().skip(1)).map(|(&n1, &n2)| Layer { + nodes: std::iter::repeat_with(|| (0..(n1 as usize)).map(|_| randaround(1.0)).collect()).take(n2 as usize).collect() + }).collect(), + } + } + + pub fn add_noise(&mut self, hard_noise: f32, fine_noise: f32, w: f32, r: f32) { + let complete = randaround(1.0).abs() < r; + for Layer { nodes: layer_mappings } in self.layers.iter_mut() { + for layer_mapping in layer_mappings.iter_mut() { + layer_mapping.iter_mut().for_each(|v| { + if complete { + *v = randaround(hard_noise); + } else { + *v += if randaround(1.0).abs() < w + { randaround(hard_noise) } else { randaround(fine_noise) } + } + }); + } + } + } + + pub fn new_by_game_res Self>(cols: u32, rows: u32, f: &F) -> Self { + // fields + 1xplayer_y + 1xplayer_v + let input_layer_size = cols * rows + 1 + 1; + f(&[input_layer_size, + rows * cols, + cols, + 1]) + } + + fn status_to_layer(status: &Status, _cols: u32, rows: u32) -> Vec { + [status.player[1] as f32 / rows as f32, status.player_v as f32].iter() + .map(|&x| x) + .chain( + status.fields.iter().map(|f| match f { + Field::Air => 0.0, + Field::Wall => 1.0, + Field::Spike => -1.0, + }) + ).collect() + } + + pub fn execute(&self, game: G) -> u64 { + self.execute_with_logger::(game) + } + + pub fn execute_logged(&self, game: G) -> GameLog { + self.execute_with_logger::(game) + } + + fn execute_with_logger(&self, mut game: G) -> L::Output { + let (cols, rows) = game.size(); + let mut log = L::new_empty(cols, rows); + log.append(&game.status()); + while let Some(status) = game.update() { + log.append(&status); + let input_layer = Self::status_to_layer(&status, cols, rows); + let mut db = DoubleBuffer::new(input_layer, vec![]); + for Layer {nodes: layer_mappings} in self.layers.iter().skip(1) { + *db.second_mut() = layer_mappings.iter().map(|layer_mapping| { + layer_mapping.iter().zip(db.first().iter()) + .map(|(w, v)| v * w) + .sum() + }).collect(); + db.switch(); + } + //println!("out: {:?} | threshold: {}", db.first(), cols); + if db.first()[0] >= cols as f32 { + game.jump(); + } + } + log.extract() + } +} -- cgit v1.2.3-54-g00ecf