summaryrefslogtreecommitdiff
path: root/src/trainer.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/trainer.rs')
-rw-r--r--src/trainer.rs66
1 files changed, 66 insertions, 0 deletions
diff --git a/src/trainer.rs b/src/trainer.rs
new file mode 100644
index 0000000..916d515
--- /dev/null
+++ b/src/trainer.rs
@@ -0,0 +1,66 @@
+use crate::nn::Nn;
+use crate::game::{Game, JumpGame};
+use crate::gui::Gui;
+use rand::Rng;
+
+fn randto(x: f32) -> f32 {
+ rand::thread_rng().gen_range(0.0, x)
+}
+
+pub struct Trainer {
+}
+
+pub fn train() {
+ let (cols, rows) = (20, 16);
+ //let colony_size = 64;
+ let colony_size = 64;
+ let thread_count = colony_size;
+
+ let mut seed = 42;
+ let mut best: Option<Nn> = None;
+ for generation in 0.. {
+ println!("Generation #{}", generation);
+ println!("===============");
+ println!("generating colony");
+ let nn_colony: Vec<Nn> = if generation == 0 {
+ let gen = || Nn::new_by_game_res(cols, rows, &Nn::new_random);
+ std::iter::repeat_with(gen)
+ .take(colony_size).collect()
+ } else if let Some(nn) = best.clone() {
+ let gen = || {let mut nn = nn.clone(); nn.add_noise(randto(2.0), randto(0.18), randto(1.0), 0.06); nn};
+ std::iter::repeat_with(gen)
+ .take(colony_size - 1).chain(std::iter::once(nn.clone())).collect()
+ } else {
+ panic!("no colony")
+ };
+ println!(" ...done");
+
+ let points: Vec<u64> = nn_colony.chunks(thread_count).flat_map(|chunk| {
+ println!("begin chunk");
+ let threads: Vec<std::thread::JoinHandle<u64>> = chunk.iter().enumerate().map(|(i, nn)| {
+ println!("nn #{}", i);
+ let nn = nn.clone();
+ std::thread::spawn(move || {
+ let game = JumpGame::new(cols, rows, seed);
+ nn.execute(game)
+ }
+ )}).collect();
+ let mut results = Vec::with_capacity(threads.len());
+ for thread in threads {
+ results.push(thread.join().unwrap());
+ }
+ results
+ }).collect();
+ best = points.iter().zip(nn_colony.iter())
+ .max_by_key(|(ref p, _)| p.clone())
+ .map(|(_, n)| n.clone());
+ println!("points: {:?}", points);
+ println!("max: {:?}", points.iter().max());
+
+ if (generation % 25) == 0 {
+ let game = JumpGame::new(cols, rows, seed);
+ Gui::run(best.clone().unwrap().execute_logged(game));
+ seed += 1;
+ }
+ }
+}