summaryrefslogtreecommitdiff
path: root/src/trainer.rs
blob: 916d5157bb68774f09d0b71f82b995df20d4951e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
use crate::nn::Nn;
use crate::game::{Game, JumpGame};
use crate::gui::Gui;
use rand::Rng;

/// Draws a uniformly distributed `f32` from the half-open range `[0.0, x)`
/// using the thread-local RNG.
///
/// Assumes `x > 0.0`: rand's two-argument `gen_range(low, high)` is
/// documented to panic when `low >= high`.
fn randto(x: f32) -> f32 {
    let mut rng = rand::thread_rng();
    rng.gen_range(0.0, x)
}

/// Placeholder type for training state.
///
/// It currently has no fields, and the `train` function in this module keeps
/// all of its state in local variables instead of using it.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Trainer {}

pub fn train() {
    let (cols, rows) = (20, 16);
    //let colony_size = 64;
    let colony_size = 64;
    let thread_count = colony_size;

    let mut seed = 42;
    let mut best: Option<Nn> = None;
    for generation in 0.. {
        println!("Generation #{}", generation);
        println!("===============");
        println!("generating colony");
        let nn_colony: Vec<Nn> = if generation == 0 {
            let gen = || Nn::new_by_game_res(cols, rows, &Nn::new_random);
            std::iter::repeat_with(gen)
                .take(colony_size).collect()
        } else if let Some(nn) = best.clone() {
            let gen = || {let mut nn = nn.clone(); nn.add_noise(randto(2.0), randto(0.18), randto(1.0), 0.06); nn};
            std::iter::repeat_with(gen)
                .take(colony_size - 1).chain(std::iter::once(nn.clone())).collect()
        } else {
            panic!("no colony")
        };
        println!(" ...done");

        let points: Vec<u64> = nn_colony.chunks(thread_count).flat_map(|chunk| {
            println!("begin chunk");
            let threads: Vec<std::thread::JoinHandle<u64>> = chunk.iter().enumerate().map(|(i, nn)| {
                     println!("nn #{}", i);
                     let nn = nn.clone();
                     std::thread::spawn(move || {
                         let game = JumpGame::new(cols, rows, seed);
                         nn.execute(game)
                     }
                 )}).collect();
            let mut results = Vec::with_capacity(threads.len());
            for thread in threads {
                results.push(thread.join().unwrap());
            }
            results
        }).collect();
        best = points.iter().zip(nn_colony.iter())
                                .max_by_key(|(ref p, _)| p.clone())
                                .map(|(_, n)| n.clone());
        println!("points: {:?}", points);
        println!("max: {:?}", points.iter().max());

        if (generation % 25) == 0 {
            let game = JumpGame::new(cols, rows, seed);
            Gui::run(best.clone().unwrap().execute_logged(game));
            seed += 1;
        }
    }
}