From 223119654b4835c270ce4215f156e3c6236833b8 Mon Sep 17 00:00:00 2001 From: natrixaeria Date: Thu, 19 Dec 2019 20:16:20 +0100 Subject: Initial commit --- src/trainer.rs | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 src/trainer.rs (limited to 'src/trainer.rs') diff --git a/src/trainer.rs b/src/trainer.rs new file mode 100644 index 0000000..916d515 --- /dev/null +++ b/src/trainer.rs @@ -0,0 +1,66 @@ +use crate::nn::Nn; +use crate::game::{Game, JumpGame}; +use crate::gui::Gui; +use rand::Rng; + +fn randto(x: f32) -> f32 { + rand::thread_rng().gen_range(0.0, x) +} + +pub struct Trainer { +} + +pub fn train() { + let (cols, rows) = (20, 16); + //let colony_size = 64; + let colony_size = 64; + let thread_count = colony_size; + + let mut seed = 42; + let mut best: Option = None; + for generation in 0.. { + println!("Generation #{}", generation); + println!("==============="); + println!("generating colony"); + let nn_colony: Vec = if generation == 0 { + let gen = || Nn::new_by_game_res(cols, rows, &Nn::new_random); + std::iter::repeat_with(gen) + .take(colony_size).collect() + } else if let Some(nn) = best.clone() { + let gen = || {let mut nn = nn.clone(); nn.add_noise(randto(2.0), randto(0.18), randto(1.0), 0.06); nn}; + std::iter::repeat_with(gen) + .take(colony_size - 1).chain(std::iter::once(nn.clone())).collect() + } else { + panic!("no colony") + }; + println!(" ...done"); + + let points: Vec = nn_colony.chunks(thread_count).flat_map(|chunk| { + println!("begin chunk"); + let threads: Vec> = chunk.iter().enumerate().map(|(i, nn)| { + println!("nn #{}", i); + let nn = nn.clone(); + std::thread::spawn(move || { + let game = JumpGame::new(cols, rows, seed); + nn.execute(game) + } + )}).collect(); + let mut results = Vec::with_capacity(threads.len()); + for thread in threads { + results.push(thread.join().unwrap()); + } + results + }).collect(); + best = points.iter().zip(nn_colony.iter()) + .max_by_key(|(ref p, _)| p.clone()) + .map(|(_, n)| n.clone()); + println!("points: {:?}", points); + println!("max: {:?}", points.iter().max()); + + if (generation % 25) == 0 { + let game = JumpGame::new(cols, rows, seed); + Gui::run(best.clone().unwrap().execute_logged(game)); + seed += 1; + } + } +} -- cgit v1.2.3-54-g00ecf