1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
|
use burn::{
nn::{Linear, Relu},
prelude::*,
};
use nn::LinearConfig;
use burn::record::{FullPrecisionSettings, Recorder};
use burn_import::pytorch::PyTorchFileRecorder;
/// Backend used by [`load_model`] to materialize the network
/// (the Candle backend with its default type parameters).
pub type BurnBackend = burn::backend::candle::Candle;
// NOTE(review): field names presumably must match the parameter names stored in the
// PyTorch checkpoint read by `load_model`; keep them in sync with the exporting script.
/// Small fully connected network: three [`Linear`] layers, with a [`Relu`]
/// activation after each of the first two.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
/// First linear layer (8 -> 4 features as built in `Net::init`).
input_lin: Linear<B>,
/// Activation applied to the output of `input_lin`.
relu1: Relu,
/// Second linear layer (4 -> 2 features as built in `Net::init`).
lin2: Linear<B>,
/// Activation applied to the output of `lin2`.
relu2: Relu,
/// Output layer (2 -> 1 feature as built in `Net::init`); no activation follows it.
lin3: Linear<B>,
}
impl<B: Backend> Net<B> {
    /// Number of features expected in the last dimension of the input.
    const INPUT_SIZE: usize = 8;
    /// Width of the first hidden layer.
    const HIDDEN1_SIZE: usize = 4;
    /// Width of the second hidden layer.
    const HIDDEN2_SIZE: usize = 2;
    /// Number of features produced in the last dimension of the output.
    const OUTPUT_SIZE: usize = 1;

    /// Create a new model with freshly initialized weights on `device`.
    ///
    /// Architecture: `8 -> Linear -> 4 -> ReLU -> Linear -> 2 -> ReLU -> Linear -> 1`.
    /// The widths must match any checkpoint later loaded into this model.
    pub fn init(device: &B::Device) -> Self {
        // Struct-literal fields are evaluated in source order, so the layers are
        // initialized in the same order as before (relevant for random init).
        Self {
            input_lin: LinearConfig::new(Self::INPUT_SIZE, Self::HIDDEN1_SIZE).init(device),
            relu1: Relu::new(),
            lin2: LinearConfig::new(Self::HIDDEN1_SIZE, Self::HIDDEN2_SIZE).init(device),
            relu2: Relu::new(),
            lin3: LinearConfig::new(Self::HIDDEN2_SIZE, Self::OUTPUT_SIZE).init(device),
        }
    }

    /// Forward pass of the model.
    ///
    /// Accepts a tensor of any rank `D` whose last dimension holds the 8 input
    /// features: a single sample of shape `[8]` (the previous, rank-1-only usage)
    /// or a batch such as `[batch, 8]`. Leading dimensions are carried through,
    /// and the last dimension of the result has size 1.
    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.input_lin.forward(x);
        let x = self.relu1.forward(x);
        let x = self.lin2.forward(x);
        let x = self.relu2.forward(x);
        self.lin3.forward(x)
    }
}
/// Load the p core model from the file in your source code (not in build.rs or script).
pub fn load_model(path: &str) -> Net<BurnBackend> {
let device = Default::default();
let record: NetRecord<BurnBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
.load(path.into(), &device)
.expect("Failed to decode state");
Net::<BurnBackend>::init(&device).load_record(record)
}
|