diff options
author | Dennis Kobert <dennis@kobert.dev> | 2025-04-02 10:20:43 +0200 |
---|---|---|
committer | Dennis Kobert <dennis@kobert.dev> | 2025-04-02 10:20:43 +0200 |
commit | c8c05d29419822aff3554af788e910ec69267406 (patch) | |
tree | f978c5f3720c1075c2cc9a12fc029e874e3a1c7c | |
parent | 1c31c2d7068737af583d76ec0e7dc12125a5c34d (diff) |
Implement energy model for e cores
-rw-r--r-- | src/benchmark.rs | 126 | ||||
-rw-r--r-- | src/energy.rs | 84 | ||||
-rw-r--r-- | src/energy/estimator.rs | 2 | ||||
-rw-r--r-- | src/energy/trackers/kernel.rs | 3 | ||||
-rw-r--r-- | src/energy/trackers/mock.rs | 4 | ||||
-rw-r--r-- | src/energy/trackers/perf.rs | 47 | ||||
-rw-r--r-- | src/main.rs | 17 | ||||
-rw-r--r-- | src/model.rs | 16 |
8 files changed, 167 insertions, 132 deletions
diff --git a/src/benchmark.rs b/src/benchmark.rs index 77202a8..4876159 100644 --- a/src/benchmark.rs +++ b/src/benchmark.rs @@ -10,20 +10,20 @@ use perf_event::{ use rand::seq::IteratorRandom; use scx_utils::Topology; use scx_utils::UserExitInfo; -use std::fs::File; use std::mem::MaybeUninit; use std::process; use std::thread; use std::time::{Duration, Instant}; use std::{collections::HashMap, i32, ops::Range}; +use std::{fs::File, sync::atomic::AtomicI32}; -const SLICE_US: u64 = 50000; -const LOG_INTERVAL_MS: u64 = 10; // Log every 1 second - // const RESHUFFLE_ROUNDS: usize = 5; // Number of rounds before reshuffling counters -const RESHUFFLE_ROUNDS: usize = 1; // Number of rounds before reshuffling counters +const SLICE_US: u64 = 5000; +const LOG_INTERVAL_MS: u64 = 10; +const RESHUFFLE_ROUNDS: usize = 1000; // Number of rounds before changing cpu const MAX_COUNTERS_AT_ONCE_P_CORE: usize = 7; const MAX_COUNTERS_AT_ONCE_E_CORE: usize = 8; type Pid = i32; +static CPU: AtomicI32 = AtomicI32::new(0); pub struct BenchmarkScheduler<'a> { bpf: BpfScheduler<'a>, @@ -69,7 +69,7 @@ impl Measurement { } // Take a measurement with the given counter group - fn take(counters: &[(String, Counter)], group: &mut Group) -> Result<Self> { + fn take(counters: &[(String, Counter)], group: &mut Group, cpu_id: u32) -> Result<Self> { let mut measurement = Self::new(); // Read energy @@ -77,7 +77,7 @@ impl Measurement { measurement.energy = rapl::read_package_energy().ok(); // Read CPU frequency - measurement.frequency = read_cpu_frequency(0); + measurement.frequency = read_cpu_frequency(cpu_id); // Read performance counters let counts = group.read()?; @@ -118,6 +118,7 @@ impl Measurement { energy_delta, frequency: self.frequency, counter_deltas: self.counter_values.clone(), + e_core: 0, } } } @@ -125,6 +126,7 @@ impl Measurement { // Represents the difference between two measurements struct MeasurementDiff { timestamp: Instant, + e_core: u8, duration_ms: u64, energy_delta: f64, frequency: Option<f64>, @@ -138,6 +140,7 @@ impl MeasurementDiff { let mut record = vec![ self.timestamp.elapsed().as_secs_f64().to_string(), self.duration_ms.to_string(), + self.e_core.to_string(), self.energy_delta.to_string(), self.frequency .map(|f| f.to_string()) @@ -201,23 +204,12 @@ impl<'a> BenchmarkScheduler<'a> { while let Ok(Some(task)) = self.bpf.dequeue_task() { let mut dispatched_task = DispatchedTask::new(&task); - match self.mode { - // If it's our own process, schedule it to core 1 - Mode::PCores => { - if task.pid == self.own_pid { - dispatched_task.cpu = self.p_cores.start + 1; - } else { - // Schedule all other tasks on core 0 - dispatched_task.cpu = self.p_cores.start; - } - } - Mode::ECores => { - if task.pid == self.own_pid { - dispatched_task.cpu = self.e_cores.start + 1; - } else { - dispatched_task.cpu = self.e_cores.start; - } - } + let cpu = CPU.load(std::sync::atomic::Ordering::Relaxed); + if task.pid == self.own_pid { + dispatched_task.cpu = cpu + 1; + } else { + // Schedule all other tasks on core 0 + dispatched_task.cpu = cpu; } dispatched_task.slice_ns = SLICE_US; @@ -240,15 +232,7 @@ impl<'a> BenchmarkScheduler<'a> { let e_cores = self.e_cores.clone(); let p_cores = self.p_cores.clone(); thread::spawn(move || { - if let Err(e) = run_measurement_loop( - log_path, - mode, - if mode == Mode::PCores { - p_cores.start - } else { - e_cores.start - }, - ) { + if let Err(e) = run_measurement_loop(log_path, mode, p_cores.start, e_cores.start) { eprintln!("Measurement thread error: {:?}", e); } }) @@ -269,7 +253,7 @@ impl<'a> BenchmarkScheduler<'a> { } // Main measurement loop -fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Result<()> { +fn run_measurement_loop(log_path: String, mode: Mode, p_core: i32, e_core: i32) -> Result<()> { // Define available hardware counters let available_events = define_available_events(); @@ -279,12 +263,16 @@ fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Re let mut rng = rand::rng(); let mut round_counter = 0; + let mut cpu_to_monitor = p_core; println!("Monitoring: {cpu_to_monitor}"); // Main measurement loop loop { // println!("Starting new counter group (round {})", round_counter); round_counter += 1; + let is_e_core = round_counter % 2 == 0; + cpu_to_monitor = if is_e_core { p_core } else { e_core }; + CPU.store(cpu_to_monitor, std::sync::atomic::Ordering::Relaxed); // Create a new perf group let mut group = match Group::new_with_pid_and_cpu(-1, cpu_to_monitor) { @@ -297,14 +285,9 @@ fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Re }; // Select random subset of counters - let selected_events = available_events.iter().choose_multiple( - &mut rng, - if mode == Mode::PCores { - MAX_COUNTERS_AT_ONCE_P_CORE - } else { - MAX_COUNTERS_AT_ONCE_E_CORE - }, - ); + let selected_events = available_events + .iter() + .choose_multiple(&mut rng, MAX_COUNTERS_AT_ONCE_P_CORE); //let selected_events = available_events[0..MAX_COUNTERS_AT_ONCE_E_CORE].iter(); @@ -350,14 +333,15 @@ fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Re // ); // Take initial measurement - let mut prev_measurement = match Measurement::take(&counters, &mut group) { - Ok(m) => m, - Err(e) => { - eprintln!("Failed to take initial measurement: {}", e); - thread::sleep(Duration::from_millis(100)); - continue; - } - }; + let mut prev_measurement = + match Measurement::take(&counters, &mut group, cpu_to_monitor as u32) { + Ok(m) => m, + Err(e) => { + eprintln!("Failed to take initial measurement: {}", e); + thread::sleep(Duration::from_millis(100)); + continue; + } + }; // println!("Took initial measurement"); @@ -369,23 +353,28 @@ fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Re thread::sleep(Duration::from_millis(LOG_INTERVAL_MS)); // Take current measurement - let curr_measurement = match Measurement::take(&counters, &mut group) { - Ok(m) => m, - Err(e) => { - eprintln!("Failed to take measurement in round {}: {}", round, e); - continue; - } - }; + let curr_measurement = + match Measurement::take(&counters, &mut group, cpu_to_monitor as u32) { + Ok(m) => m, + Err(e) => { + eprintln!("Failed to take measurement in round {}: {}", round, e); + continue; + } + }; // Calculate difference and write to CSV - let diff = curr_measurement.diff(&prev_measurement); + let mut diff = curr_measurement.diff(&prev_measurement); // println!( // "Measurement diff: duration={}ms, energy={}J", // diff.duration_ms, diff.energy_delta // ); + diff.e_core = if is_e_core { 1 } else { 0 }; - if let Err(e) = diff.write_csv_record(&mut csv_writer) { - eprintln!("Failed to write CSV record: {}", e); + // We have to throw away the first few measurements after changing from one core to the other to avoid noise from tasks executing on both cores at the same time + if round >= 250 { + if let Err(e) = diff.write_csv_record(&mut csv_writer) { + eprintln!("Failed to write CSV record: {}", e); + } } // Current becomes previous for next iteration @@ -407,6 +396,7 @@ fn initialize_csv_writer( let mut header = vec![ "timestamp".to_string(), "duration_ms".to_string(), + "is_e_core".to_string(), "package_power_j".to_string(), "cpu_frequency_mhz".to_string(), ]; @@ -459,10 +449,10 @@ fn define_available_events() -> Vec<(String, Event)> { "cache_misses".to_string(), Event::Hardware(Hardware::CACHE_MISSES), ), - ( - "branch_instructions".to_string(), - Event::Hardware(Hardware::BRANCH_INSTRUCTIONS), - ), + // ( + // "branch_instructions".to_string(), + // Event::Hardware(Hardware::BRANCH_INSTRUCTIONS), + // ), ( "branch_misses".to_string(), Event::Hardware(Hardware::BRANCH_MISSES), @@ -471,10 +461,10 @@ fn define_available_events() -> Vec<(String, Event)> { "ref_cpu_cycles".to_string(), Event::Hardware(Hardware::REF_CPU_CYCLES), ), - ( - "task_clock".to_string(), - Event::Software(Software::TASK_CLOCK), - ), + // ( + // "task_clock".to_string(), + // Event::Software(Software::TASK_CLOCK), + // ), // ( // "stalled-cycles-frontend".to_string(), // Event::Hardware(Hardware::STALLED_CYCLES_FRONTEND), diff --git a/src/energy.rs b/src/energy.rs index 35ead3b..6692a63 100644 --- a/src/energy.rs +++ b/src/energy.rs @@ -9,7 +9,7 @@ use std::ops::RangeInclusive; use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64}; use std::sync::{mpsc, Arc, RwLock}; use std::thread; -use std::time::Duration; +use std::time::{Duration, Instant}; use crate::freq::FrequencyKHZ; use crate::socket; @@ -18,6 +18,9 @@ use crate::Pid; pub use budget::BudgetPolicy; pub use trackers::{KernelDriver, PerfEstimator}; +const IDLE_CONSUMPTION_W: f64 = 7.; +const UPDATE_INTERVAL_MS: u64 = 3; + pub enum Request { NewTask(Pid, Arc<TaskInfo>), RemoveTask(Pid), @@ -85,8 +88,11 @@ pub struct EnergyService { shared_cpu_current_frequencies: Arc<RwLock<Vec<FrequencyKHZ>>>, rapl_offset: f64, old_rapl: f64, + system_energy: f64, bias: f64, offset: f64, + graveyard: Vec<i32>, + last_measurement: Instant, } impl EnergyService { @@ -112,22 +118,31 @@ impl EnergyService { shared_cpu_current_frequencies, rapl_offset: rapl::read_package_energy().unwrap(), old_rapl: 0., + system_energy: 0., bias: 1., offset: 0., + graveyard: Vec::with_capacity(100), + last_measurement: Instant::now(), } } pub fn run(mut self) { thread::spawn(move || { + let mut i = 0; loop { + i += 1; // Process any incoming requests self.handle_requests(); - // Update energy measurements - self.update_measurements(); + if i % 30 == 0 { + // Update energy measurements + self.update_measurements(); + + self.clear_graveyeard(); - // Calculate and update budgets - self.update_budgets(); + // Calculate and update budgets + self.update_budgets(); + } // Sleep for update interval thread::sleep(self.update_interval); @@ -150,11 +165,17 @@ impl EnergyService { info.task_info.set_budget(old_budget); return; } - self.estimator.start_trace( - pid as u64, - task_info.read_cpu(), - task_info.is_running_on_e_core(), - ); + if self + .estimator + .start_trace( + pid as u64, + task_info.read_cpu(), + task_info.is_running_on_e_core(), + ) + .is_err() + { + return; + } let parent = (|| { let process = procfs::process::Process::new(pid)?; process.stat().map(|stat| stat.ppid) @@ -179,10 +200,7 @@ impl EnergyService { if procfs::process::Process::new(pid).is_ok() { return; } - - self.estimator.stop_trace(pid as u64); - self.process_info.write().unwrap().remove(&pid); - self.process_info.write().unwrap().remove(&pid); + self.graveyard.push(pid); } } } @@ -214,22 +232,28 @@ impl EnergyService { } } } + let elapsed = self.last_measurement.elapsed(); + self.last_measurement = Instant::now(); if let Some(init) = self.process_info.write().unwrap().get_mut(&1) { let rapl = rapl::read_package_energy().unwrap() - self.rapl_offset; let rapl_diff = rapl - self.old_rapl; - let est_diff = init.tree_energy - old_energy; - if est_diff < 0.1 { - self.offset = (self.offset + (rapl_diff - est_diff)) * 0.5; - } + let idle_consumption = elapsed.as_secs_f64() * IDLE_CONSUMPTION_W; + let est_diff = init.tree_energy - old_energy + idle_consumption; self.old_rapl = rapl; - init.tree_energy = init.tree_energy + self.offset; - let offset_bias = (rapl / init.tree_energy).clamp(0.1, 2.); - let diff_bias = (rapl_diff / est_diff).clamp(0.1, 2.); - let current_bias = (offset_bias + diff_bias) * 0.5; - self.bias = (self.bias * ((1. / 3.) * current_bias + (2. / 3.))).clamp(0.1, 20.); + // let offset_bias = (rapl / (init.tree_energy + idle_consumption)).clamp(0.1, 2.); + let current_bias = if init.tree_energy - old_energy > idle_consumption * 0.5 { + (rapl_diff / est_diff).clamp(0.1, 2.) + } else { + 1. + }; + // let current_bias = (offset_bias + diff_bias) * 0.5; + let alpha: f64 = 10. * elapsed.as_secs_f64().recip(); + self.bias = (self.bias * (alpha.recip() * current_bias + ((alpha - 1.) / alpha))) + .clamp(0.1, 5.); + self.system_energy += est_diff; println!( - "Energy estimation: {:.1} rapl: {:.1}, est diff: {:.1} rapl diff: {:.1}", - init.tree_energy, rapl, est_diff, rapl_diff, + "Energy estimation: {:.1} rapl: {:.1}, est diff: {:.1} rapl diff: {:.1}, bias: {:.1}", + self.system_energy, rapl, est_diff, rapl_diff, self.bias, ); } } @@ -248,6 +272,14 @@ impl EnergyService { } } + fn clear_graveyeard(&mut self) { + for pid in self.graveyard.drain(..) { + self.estimator.stop_trace(pid as u64); + self.active_processes.remove(&pid); + self.process_info.write().unwrap().remove(&pid); + } + } + // Accessor methods for BudgetPolicy pub fn active_processes(&self) -> &BTreeSet<Pid> { &self.active_processes @@ -306,7 +338,7 @@ pub fn start_energy_service( budget_policy, process_info.clone(), request_receiver, - Duration::from_millis(50), // 50ms update interval + Duration::from_millis(UPDATE_INTERVAL_MS), // 50ms update interval shared_cpu_frequency_ranges, shared_policy_frequency_ranges, shared_cpu_current_frequencies, diff --git a/src/energy/estimator.rs b/src/energy/estimator.rs index fbab744..03034db 100644 --- a/src/energy/estimator.rs +++ b/src/energy/estimator.rs @@ -1,5 +1,5 @@ pub trait Estimator: Send + 'static { - fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool); + fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) -> Result<(), ()>; fn stop_trace(&mut self, pid: u64); fn update_information(&mut self, pid: u64, cpu: i32, is_ecore: bool); fn read_consumption(&mut self, pid: u64) -> Option<f64>; diff --git a/src/energy/trackers/kernel.rs b/src/energy/trackers/kernel.rs index f42bb16..a0178a4 100644 --- a/src/energy/trackers/kernel.rs +++ b/src/energy/trackers/kernel.rs @@ -30,8 +30,9 @@ const STOP_TRACE: Ioctl<Write, &u64> = unsafe { PERF_MON.write(0x81) }; const READ_POWER: Ioctl<WriteRead, &u64> = unsafe { PERF_MON.write_read(0x82) }; impl Estimator for KernelDriver { - fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) { + fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) -> Result<(), ()> { let _ = START_TRACE.ioctl(&mut self.file, &pid); + Ok(()) } fn stop_trace(&mut self, pid: u64) { diff --git a/src/energy/trackers/mock.rs b/src/energy/trackers/mock.rs index d9ede50..cd08c34 100644 --- a/src/energy/trackers/mock.rs +++ b/src/energy/trackers/mock.rs @@ -3,7 +3,9 @@ use crate::energy::estimator::Estimator; pub struct MockEstimator; impl Estimator for MockEstimator { - fn start_trace(&mut self, _pid: u64, _cpu: i32, _running_on_e_core: bool) {} + fn start_trace(&mut self, _pid: u64, _cpu: i32, _running_on_e_core: bool) -> Result<(), ()> { + Ok(()) + } fn stop_trace(&mut self, _pid: u64) {} diff --git a/src/energy/trackers/perf.rs b/src/energy/trackers/perf.rs index e59057d..38cefe9 100644 --- a/src/energy/trackers/perf.rs +++ b/src/energy/trackers/perf.rs @@ -22,8 +22,10 @@ pub struct PerfEstimator { impl PerfEstimator { pub fn new(shared_cpu_current_frequencies: Arc<RwLock<Vec<FrequencyKHZ>>>) -> Self { - let model_p = crate::model::load_model_p(); - let model_e = crate::model::load_model_e(); + // let model_p = crate::model::load_model("perf_pcore.pt"); + let model_p = crate::model::load_model("perf.pt"); + let model_e = crate::model::load_model("perf.pt"); + // let model_e = crate::model::load_model("perf_ecore.pt"); Self { registry: Default::default(), model_p, @@ -51,22 +53,18 @@ static EVENT_TYPES_P: &[Event] = &[ Event::Hardware(Hardware::CPU_CYCLES), Event::Hardware(Hardware::INSTRUCTIONS), Event::Hardware(Hardware::REF_CPU_CYCLES), - Event::Software(Software::TASK_CLOCK), ]; -//TODO: use correct counter static EVENT_TYPES_E: &[Event] = &[ - Event::Hardware(Hardware::BRANCH_INSTRUCTIONS), Event::Hardware(Hardware::BRANCH_MISSES), Event::Hardware(Hardware::CACHE_MISSES), Event::Hardware(Hardware::CACHE_REFERENCES), Event::Hardware(Hardware::CPU_CYCLES), Event::Hardware(Hardware::INSTRUCTIONS), Event::Hardware(Hardware::REF_CPU_CYCLES), - Event::Software(Software::TASK_CLOCK), ]; impl Estimator for PerfEstimator { - fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) { + fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) -> Result<(), ()> { let mut group = match Group::new_with_pid_and_cpu(pid as i32, -1) { Ok(counters) => counters, Err(e) => { @@ -74,11 +72,12 @@ impl Estimator for PerfEstimator { "Failed to create performance counter group for PID {}: {}", pid, e ); - return; + return Err(()); } }; - let counters: Result<Vec<_>, _> = if running_on_e_core { + let counters: Result<Vec<_>, _> = if running_on_e_core || true { + // println!("starting e core counter"); EVENT_TYPES_E } else { EVENT_TYPES_P @@ -100,28 +99,30 @@ impl Estimator for PerfEstimator { "Failed to create performance counter group for PID {}: {}", pid, e ); - return; + return Err(()); } }; if let Err(e) = group.enable() { eprintln!("Failed to enable performance counters: {}", e); - return; + return Err(()); } if let Err(e) = group.reset() { eprintln!("Failed to reset performance counters: {}", e); - return; + return Err(()); } + let old_time = group.read().unwrap().time_running(); let counters = Counters { counters, group, - old_time: 0, + old_time, old_total_energy: 0., cpu, running_on_e_core, }; self.registry.insert(pid, counters); + Ok(()) } fn stop_trace(&mut self, pid: u64) { @@ -132,18 +133,21 @@ impl Estimator for PerfEstimator { let mut core_type_changed = false; if let Some(info) = self.registry.get_mut(&pid) { info.cpu = cpu; + info.running_on_e_core = is_ecore; core_type_changed = is_ecore != info.running_on_e_core; } else { eprintln!("Tried to update an unknown task") } if core_type_changed { - self.stop_trace(pid); - self.stop_trace(pid); + // println!("migrating task to {}", cpu); + // self.stop_trace(pid); + // self.start_trace(pid, cpu, is_ecore); } } fn read_consumption(&mut self, pid: u64) -> Option<f64> { let Some(counters) = self.registry.get_mut(&pid) else { + println!("did not find counters for {pid}"); return None; }; @@ -154,21 +158,19 @@ impl Estimator for PerfEstimator { return None; } }; - let num_counter = counters.counters.len(); - let task_clock = counts[&counters.counters[num_counter - 1]]; - - if task_clock == 0 { + let time_running_ns = counts.time_running(); + if time_running_ns - counters.old_time == 0 || counts.iter().next().unwrap().1 == &0 { return None; } - let time_running_ns = counts.time_running(); let correction_factor = 10_000_000. / (time_running_ns - counters.old_time) as f64; counters.old_time = time_running_ns; let mut values = vec![ + //if counters.running_on_e_core { 1. } else { 0. }, (self.shared_cpu_current_frequencies.read().unwrap()[counters.cpu as usize] / 1000) as f64, ]; - for ty in counters.counters.iter().take(num_counter - 1) { + for ty in counters.counters.iter() { let count: u64 = counts[&ty]; values.push((count as f64) * correction_factor); } @@ -181,6 +183,9 @@ impl Estimator for PerfEstimator { .forward(Tensor::from_floats(&values.as_slice()[0..], &self.device)); let energy = result.into_scalar() as f64; + if counters.running_on_e_core { + // dbg!(energy); + } counters.old_total_energy += energy / correction_factor; counters.group.reset().unwrap(); Some(energy / correction_factor) diff --git a/src/main.rs b/src/main.rs index bb53fc5..9f3528c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,7 +15,9 @@ mod socket; mod bpf; use anyhow::Result; +use burn::tensor::Tensor; use clap::{Arg, ArgAction, Command}; +use model::load_model; use scheduler::Scheduler; use std::mem::MaybeUninit; @@ -48,6 +50,21 @@ fn main() -> Result<()> { ) .get_matches(); + let device = Default::default(); + let model = load_model("perf.pt"); + let tensor = Tensor::from_floats( + [ + 800., 90678., 54734., 153646., 20354478., 40948418., + 89103105., + //5200., 148947., 322426., 498965., 62340773., 144451046., 41976480., + ], + &device, + ); + let result = model.forward(tensor); + let energy: f32 = result.into_scalar(); + println!("energy: {energy}"); + // panic!(); + let power_cap = *matches.get_one::<u64>("power_cap").unwrap_or(&u64::MAX); let use_mocking = matches.get_flag("mock"); let benchmark = matches.get_one::<String>("benchmark"); diff --git a/src/model.rs b/src/model.rs index fbb6e0b..6dbde1f 100644 --- a/src/model.rs +++ b/src/model.rs @@ -50,22 +50,10 @@ impl<B: Backend> Net<B> { } /// Load the p core model from the file in your source code (not in build.rs or script). -pub fn load_model_p() -> Net<ArrayBackend> { +pub fn load_model(path: &str) -> Net<ArrayBackend> { let device = Default::default(); let record: NetRecord<ArrayBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default() - .load("./perf.pt".into(), &device) - .expect("Failed to decode state"); - - Net::<ArrayBackend>::init(&device).load_record(record) -} - -/// Load the e core model from the file in your source code (not in build.rs or script). -pub fn load_model_e() -> Net<ArrayBackend> { - //TODO: load e model - println!("Falling back to p model"); - let device = Default::default(); - let record: NetRecord<ArrayBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default() - .load("./perf.pt".into(), &device) + .load(path.into(), &device) .expect("Failed to decode state"); Net::<ArrayBackend>::init(&device).load_record(record) |