summaryrefslogtreecommitdiff
path: root/src/model.rs
blob: eb62c35b0f11b556d09d90ceeee15275639dcf9d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
use burn::{
    nn::{Linear, Relu},
    prelude::*,
};
use nn::LinearConfig;

use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;

/// Concrete backend used by this crate's model: burn's Candle backend.
pub type BurnBackend = burn::backend::candle::Candle;

/// A small fully connected network: three linear layers (8 → 4 → 2 → 1,
/// see [`Net::init`]) with a ReLU activation after the first two.
///
/// NOTE(review): field names are significant — the derived `NetRecord`
/// is keyed by them, and they presumably must match the parameter names
/// in the PyTorch checkpoint loaded by `load_model`; confirm against the
/// checkpoint before renaming any field.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    input_lin: Linear<B>, // 8 → 4
    relu1: Relu,
    lin2: Linear<B>, // 4 → 2
    relu2: Relu,
    lin3: Linear<B>, // 2 → 1
}

impl<B: Backend> Net<B> {
    /// Build the network with freshly initialized parameters on `device`.
    ///
    /// Layer widths are fixed at 8 → 4 → 2 → 1.
    pub fn init(device: &B::Device) -> Self {
        let input_size = 8;
        Self {
            input_lin: LinearConfig::new(input_size, 4).init(device),
            relu1: Relu::new(),
            lin2: LinearConfig::new(4, 2).init(device),
            relu2: Relu::new(),
            lin3: LinearConfig::new(2, 1).init(device),
        }
    }

    /// Run a forward pass: linear → ReLU → linear → ReLU → linear.
    pub fn forward(&self, x: Tensor<B, 1>) -> Tensor<B, 1> {
        let hidden1 = self.relu1.forward(self.input_lin.forward(x));
        let hidden2 = self.relu2.forward(self.lin2.forward(hidden1));
        self.lin3.forward(hidden2)
    }
}

/// Load the model's parameters from a PyTorch checkpoint at `path`.
///
/// Intended to be called from regular source code (not from `build.rs`
/// or a build script).
///
/// # Panics
///
/// Panics if the file cannot be read or its contents cannot be decoded
/// into a `NetRecord` — the panic message includes the path and the
/// underlying recorder error.
pub fn load_model(path: &str) -> Net<BurnBackend> {
    let device = Default::default();
    // NOTE(review): the recorder matches checkpoint entries against the
    // derived record's field names — presumably `Net`'s field names must
    // line up with the PyTorch state-dict keys; confirm against the file.
    let record: NetRecord<BurnBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
        .load(path.into(), &device)
        .unwrap_or_else(|e| panic!("failed to load PyTorch checkpoint {path:?}: {e:?}"));

    // Initialize a fresh network, then overwrite its parameters with the
    // loaded record.
    Net::<BurnBackend>::init(&device).load_record(record)
}