use burn::{ nn::{Linear, Relu}, prelude::*, }; use nn::LinearConfig; use burn::record::{FullPrecisionSettings, Recorder}; use burn_import::pytorch::PyTorchFileRecorder; pub type BurnBackend = burn::backend::candle::Candle; #[derive(Module, Debug)] pub struct Net { input_lin: Linear, relu1: Relu, lin2: Linear, relu2: Relu, lin3: Linear, } impl Net { /// Create a new model. pub fn init(device: &B::Device) -> Self { let input_size = 8; let input_lin = LinearConfig::new(input_size, 4).init(device); let relu1 = Relu::new(); let lin2 = LinearConfig::new(4, 2).init(device); let relu2 = Relu::new(); let lin3 = LinearConfig::new(2, 1).init(device); Self { input_lin, relu1, lin2, relu2, lin3, } } /// Forward pass of the model. pub fn forward(&self, x: Tensor) -> Tensor { let x = self.input_lin.forward(x); let x = self.relu1.forward(x); let x = self.lin2.forward(x); let x = self.relu2.forward(x); self.lin3.forward(x) } } /// Load the p core model from the file in your source code (not in build.rs or script). pub fn load_model(path: &str) -> Net { let device = Default::default(); let record: NetRecord = PyTorchFileRecorder::::default() .load(path.into(), &device) .expect("Failed to decode state"); Net::::init(&device).load_record(record) }