summaryrefslogtreecommitdiff
path: root/src/nn.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.rs')
-rw-r--r--src/nn.rs105
1 files changed, 105 insertions, 0 deletions
diff --git a/src/nn.rs b/src/nn.rs
new file mode 100644
index 0000000..be600ea
--- /dev/null
+++ b/src/nn.rs
@@ -0,0 +1,105 @@
+use rand::Rng;
+use crate::game::{Game, GameLog, GameLogger, PointLogger, Status, Field};
+use crate::doublebuffer::DoubleBuffer;
+
+#[derive(Debug, Clone)]
+pub struct Layer {
+ /// vector of (vectors of weights onto the past layer)
+ nodes: Vec<Vec<f32>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct Nn {
+ layers: Vec<Layer>
+}
+
+fn randaround(x: f32) -> f32 {
+ rand::thread_rng().gen_range(-x, x)
+}
+
+impl Nn {
+ pub fn new_empty(layers: &[u32]) -> Self {
+ Self {
+ layers: layers.iter().take(layers.len() - 1).zip(layers.iter().skip(1)).map(|(&n1, &n2)| Layer {
+ nodes: std::iter::repeat(std::iter::repeat(1.0).take(n1 as usize).collect()).take(n2 as usize).collect()
+ }).collect(),
+ }
+ }
+
+ pub fn new_random(layers: &[u32]) -> Self {
+ Self {
+ layers: layers.iter().take(layers.len() - 1).zip(layers.iter().skip(1)).map(|(&n1, &n2)| Layer {
+ nodes: std::iter::repeat_with(|| (0..(n1 as usize)).map(|_| randaround(1.0)).collect()).take(n2 as usize).collect()
+ }).collect(),
+ }
+ }
+
+ pub fn add_noise(&mut self, hard_noise: f32, fine_noise: f32, w: f32, r: f32) {
+ let complete = randaround(1.0).abs() < r;
+ for Layer { nodes: layer_mappings } in self.layers.iter_mut() {
+ for layer_mapping in layer_mappings.iter_mut() {
+ layer_mapping.iter_mut().for_each(|v| {
+ if complete {
+ *v = randaround(hard_noise);
+ } else {
+ *v += if randaround(1.0).abs() < w
+ { randaround(hard_noise) } else { randaround(fine_noise) }
+ }
+ });
+ }
+ }
+ }
+
+ pub fn new_by_game_res<F: Fn(&[u32]) -> Self>(cols: u32, rows: u32, f: &F) -> Self {
+ // fields + 1xplayer_y + 1xplayer_v
+ let input_layer_size = cols * rows + 1 + 1;
+ f(&[input_layer_size,
+ rows * cols,
+ cols,
+ 1])
+ }
+
+ fn status_to_layer(status: &Status, _cols: u32, rows: u32) -> Vec<f32> {
+ [status.player[1] as f32 / rows as f32, status.player_v as f32].iter()
+ .map(|&x| x)
+ .chain(
+ status.fields.iter().map(|f| match f {
+ Field::Air => 0.0,
+ Field::Wall => 1.0,
+ Field::Spike => -1.0,
+ })
+ ).collect()
+ }
+
+ pub fn execute<G: Game>(&self, game: G) -> u64 {
+ self.execute_with_logger::<G, PointLogger>(game)
+ }
+
+ pub fn execute_logged<G: Game>(&self, game: G) -> GameLog {
+ self.execute_with_logger::<G, GameLog>(game)
+ }
+
+ fn execute_with_logger<G: Game, L: GameLogger>(&self, mut game: G) -> L::Output {
+ let (cols, rows) = game.size();
+ let mut log = L::new_empty(cols, rows);
+ log.append(&game.status());
+ while let Some(status) = game.update() {
+ log.append(&status);
+ let input_layer = Self::status_to_layer(&status, cols, rows);
+ let mut db = DoubleBuffer::new(input_layer, vec![]);
+ for Layer {nodes: layer_mappings} in self.layers.iter().skip(1) {
+ *db.second_mut() = layer_mappings.iter().map(|layer_mapping| {
+ layer_mapping.iter().zip(db.first().iter())
+ .map(|(w, v)| v * w)
+ .sum()
+ }).collect();
+ db.switch();
+ }
+ //println!("out: {:?} | threshold: {}", db.first(), cols);
+ if db.first()[0] >= cols as f32 {
+ game.jump();
+ }
+ }
+ log.extract()
+ }
+}