author     Dennis Kobert <dennis@kobert.dev>   2025-04-02 10:20:43 +0200
committer  Dennis Kobert <dennis@kobert.dev>   2025-04-02 10:20:43 +0200
commit     c8c05d29419822aff3554af788e910ec69267406 (patch)
tree       f978c5f3720c1075c2cc9a12fc029e874e3a1c7c
parent     1c31c2d7068737af583d76ec0e7dc12125a5c34d (diff)
Implement energy model for e cores
-rw-r--r--  src/benchmark.rs                126
-rw-r--r--  src/energy.rs                    84
-rw-r--r--  src/energy/estimator.rs           2
-rw-r--r--  src/energy/trackers/kernel.rs     3
-rw-r--r--  src/energy/trackers/mock.rs       4
-rw-r--r--  src/energy/trackers/perf.rs      47
-rw-r--r--  src/main.rs                      17
-rw-r--r--  src/model.rs                     16
8 files changed, 167 insertions(+), 132 deletions(-)
diff --git a/src/benchmark.rs b/src/benchmark.rs
index 77202a8..4876159 100644
--- a/src/benchmark.rs
+++ b/src/benchmark.rs
@@ -10,20 +10,20 @@ use perf_event::{
use rand::seq::IteratorRandom;
use scx_utils::Topology;
use scx_utils::UserExitInfo;
-use std::fs::File;
use std::mem::MaybeUninit;
use std::process;
use std::thread;
use std::time::{Duration, Instant};
use std::{collections::HashMap, i32, ops::Range};
+use std::{fs::File, sync::atomic::AtomicI32};
-const SLICE_US: u64 = 50000;
-const LOG_INTERVAL_MS: u64 = 10; // Log every 1 second
- // const RESHUFFLE_ROUNDS: usize = 5; // Number of rounds before reshuffling counters
-const RESHUFFLE_ROUNDS: usize = 1; // Number of rounds before reshuffling counters
+const SLICE_US: u64 = 5000;
+const LOG_INTERVAL_MS: u64 = 10;
+const RESHUFFLE_ROUNDS: usize = 1000; // Number of rounds before changing cpu
const MAX_COUNTERS_AT_ONCE_P_CORE: usize = 7;
const MAX_COUNTERS_AT_ONCE_E_CORE: usize = 8;
type Pid = i32;
+static CPU: AtomicI32 = AtomicI32::new(0);
pub struct BenchmarkScheduler<'a> {
bpf: BpfScheduler<'a>,
@@ -69,7 +69,7 @@ impl Measurement {
}
// Take a measurement with the given counter group
- fn take(counters: &[(String, Counter)], group: &mut Group) -> Result<Self> {
+ fn take(counters: &[(String, Counter)], group: &mut Group, cpu_id: u32) -> Result<Self> {
let mut measurement = Self::new();
// Read energy
@@ -77,7 +77,7 @@ impl Measurement {
measurement.energy = rapl::read_package_energy().ok();
// Read CPU frequency
- measurement.frequency = read_cpu_frequency(0);
+ measurement.frequency = read_cpu_frequency(cpu_id);
// Read performance counters
let counts = group.read()?;
@@ -118,6 +118,7 @@ impl Measurement {
energy_delta,
frequency: self.frequency,
counter_deltas: self.counter_values.clone(),
+ e_core: 0,
}
}
}
@@ -125,6 +126,7 @@ impl Measurement {
// Represents the difference between two measurements
struct MeasurementDiff {
timestamp: Instant,
+ e_core: u8,
duration_ms: u64,
energy_delta: f64,
frequency: Option<f64>,
@@ -138,6 +140,7 @@ impl MeasurementDiff {
let mut record = vec![
self.timestamp.elapsed().as_secs_f64().to_string(),
self.duration_ms.to_string(),
+ self.e_core.to_string(),
self.energy_delta.to_string(),
self.frequency
.map(|f| f.to_string())
@@ -201,23 +204,12 @@ impl<'a> BenchmarkScheduler<'a> {
while let Ok(Some(task)) = self.bpf.dequeue_task() {
let mut dispatched_task = DispatchedTask::new(&task);
- match self.mode {
- // If it's our own process, schedule it to core 1
- Mode::PCores => {
- if task.pid == self.own_pid {
- dispatched_task.cpu = self.p_cores.start + 1;
- } else {
- // Schedule all other tasks on core 0
- dispatched_task.cpu = self.p_cores.start;
- }
- }
- Mode::ECores => {
- if task.pid == self.own_pid {
- dispatched_task.cpu = self.e_cores.start + 1;
- } else {
- dispatched_task.cpu = self.e_cores.start;
- }
- }
+ let cpu = CPU.load(std::sync::atomic::Ordering::Relaxed);
+ if task.pid == self.own_pid {
+ dispatched_task.cpu = cpu + 1;
+ } else {
+ // Schedule all other tasks on core 0
+ dispatched_task.cpu = cpu;
}
dispatched_task.slice_ns = SLICE_US;
@@ -240,15 +232,7 @@ impl<'a> BenchmarkScheduler<'a> {
let e_cores = self.e_cores.clone();
let p_cores = self.p_cores.clone();
thread::spawn(move || {
- if let Err(e) = run_measurement_loop(
- log_path,
- mode,
- if mode == Mode::PCores {
- p_cores.start
- } else {
- e_cores.start
- },
- ) {
+ if let Err(e) = run_measurement_loop(log_path, mode, p_cores.start, e_cores.start) {
eprintln!("Measurement thread error: {:?}", e);
}
})
@@ -269,7 +253,7 @@ impl<'a> BenchmarkScheduler<'a> {
}
// Main measurement loop
-fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Result<()> {
+fn run_measurement_loop(log_path: String, mode: Mode, p_core: i32, e_core: i32) -> Result<()> {
// Define available hardware counters
let available_events = define_available_events();
@@ -279,12 +263,16 @@ fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Re
let mut rng = rand::rng();
let mut round_counter = 0;
+ let mut cpu_to_monitor = p_core;
println!("Monitoring: {cpu_to_monitor}");
// Main measurement loop
loop {
// println!("Starting new counter group (round {})", round_counter);
round_counter += 1;
+ let is_e_core = round_counter % 2 == 0;
+ cpu_to_monitor = if is_e_core { p_core } else { e_core };
+ CPU.store(cpu_to_monitor, std::sync::atomic::Ordering::Relaxed);
// Create a new perf group
let mut group = match Group::new_with_pid_and_cpu(-1, cpu_to_monitor) {
@@ -297,14 +285,9 @@ fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Re
};
// Select random subset of counters
- let selected_events = available_events.iter().choose_multiple(
- &mut rng,
- if mode == Mode::PCores {
- MAX_COUNTERS_AT_ONCE_P_CORE
- } else {
- MAX_COUNTERS_AT_ONCE_E_CORE
- },
- );
+ let selected_events = available_events
+ .iter()
+ .choose_multiple(&mut rng, MAX_COUNTERS_AT_ONCE_P_CORE);
//let selected_events = available_events[0..MAX_COUNTERS_AT_ONCE_E_CORE].iter();
@@ -350,14 +333,15 @@ fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Re
// );
// Take initial measurement
- let mut prev_measurement = match Measurement::take(&counters, &mut group) {
- Ok(m) => m,
- Err(e) => {
- eprintln!("Failed to take initial measurement: {}", e);
- thread::sleep(Duration::from_millis(100));
- continue;
- }
- };
+ let mut prev_measurement =
+ match Measurement::take(&counters, &mut group, cpu_to_monitor as u32) {
+ Ok(m) => m,
+ Err(e) => {
+ eprintln!("Failed to take initial measurement: {}", e);
+ thread::sleep(Duration::from_millis(100));
+ continue;
+ }
+ };
// println!("Took initial measurement");
@@ -369,23 +353,28 @@ fn run_measurement_loop(log_path: String, mode: Mode, cpu_to_monitor: i32) -> Re
thread::sleep(Duration::from_millis(LOG_INTERVAL_MS));
// Take current measurement
- let curr_measurement = match Measurement::take(&counters, &mut group) {
- Ok(m) => m,
- Err(e) => {
- eprintln!("Failed to take measurement in round {}: {}", round, e);
- continue;
- }
- };
+ let curr_measurement =
+ match Measurement::take(&counters, &mut group, cpu_to_monitor as u32) {
+ Ok(m) => m,
+ Err(e) => {
+ eprintln!("Failed to take measurement in round {}: {}", round, e);
+ continue;
+ }
+ };
// Calculate difference and write to CSV
- let diff = curr_measurement.diff(&prev_measurement);
+ let mut diff = curr_measurement.diff(&prev_measurement);
// println!(
// "Measurement diff: duration={}ms, energy={}J",
// diff.duration_ms, diff.energy_delta
// );
+ diff.e_core = if is_e_core { 1 } else { 0 };
- if let Err(e) = diff.write_csv_record(&mut csv_writer) {
- eprintln!("Failed to write CSV record: {}", e);
+ // We have to throw away the first few measurements after changing from one core to the other to avoid noise from tasks executing on both cores at the same time
+ if round >= 250 {
+ if let Err(e) = diff.write_csv_record(&mut csv_writer) {
+ eprintln!("Failed to write CSV record: {}", e);
+ }
}
// Current becomes previous for next iteration
@@ -407,6 +396,7 @@ fn initialize_csv_writer(
let mut header = vec![
"timestamp".to_string(),
"duration_ms".to_string(),
+ "is_e_core".to_string(),
"package_power_j".to_string(),
"cpu_frequency_mhz".to_string(),
];
@@ -459,10 +449,10 @@ fn define_available_events() -> Vec<(String, Event)> {
"cache_misses".to_string(),
Event::Hardware(Hardware::CACHE_MISSES),
),
- (
- "branch_instructions".to_string(),
- Event::Hardware(Hardware::BRANCH_INSTRUCTIONS),
- ),
+ // (
+ // "branch_instructions".to_string(),
+ // Event::Hardware(Hardware::BRANCH_INSTRUCTIONS),
+ // ),
(
"branch_misses".to_string(),
Event::Hardware(Hardware::BRANCH_MISSES),
@@ -471,10 +461,10 @@ fn define_available_events() -> Vec<(String, Event)> {
"ref_cpu_cycles".to_string(),
Event::Hardware(Hardware::REF_CPU_CYCLES),
),
- (
- "task_clock".to_string(),
- Event::Software(Software::TASK_CLOCK),
- ),
+ // (
+ // "task_clock".to_string(),
+ // Event::Software(Software::TASK_CLOCK),
+ // ),
// (
// "stalled-cycles-frontend".to_string(),
// Event::Hardware(Hardware::STALLED_CYCLES_FRONTEND),
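The benchmark changes above replace the fixed per-mode CPU with a core that alternates every round: the measurement thread publishes the currently monitored core through the CPU atomic, and the dispatch path reads it when placing tasks (its own PID goes to the neighbouring core, everything else to the monitored one). A minimal, self-contained sketch of that hand-off, with hypothetical core ids and a simplified alternation, not the exact constants from the diff:

use std::sync::atomic::{AtomicI32, Ordering};

// Shared between the measurement loop and the dispatch path, as in benchmark.rs.
static CPU: AtomicI32 = AtomicI32::new(0);

// Writer side: alternate the monitored core each round and publish it.
fn select_core_for_round(round: usize, p_core: i32, e_core: i32) -> i32 {
    let cpu = if round % 2 == 0 { e_core } else { p_core };
    CPU.store(cpu, Ordering::Relaxed);
    cpu
}

// Reader side: pin the scheduler's own process next to the monitored core,
// and all other tasks onto it, mirroring the dispatch hunk above.
fn target_cpu(task_pid: i32, own_pid: i32) -> i32 {
    let cpu = CPU.load(Ordering::Relaxed);
    if task_pid == own_pid { cpu + 1 } else { cpu }
}

fn main() {
    let (p_core, e_core) = (0, 8); // hypothetical core ids
    for round in 1..4 {
        let monitored = select_core_for_round(round, p_core, e_core);
        println!(
            "round {round}: monitor cpu {monitored}, other tasks -> cpu {}",
            target_cpu(1234, 999)
        );
    }
}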
diff --git a/src/energy.rs b/src/energy.rs
index 35ead3b..6692a63 100644
--- a/src/energy.rs
+++ b/src/energy.rs
@@ -9,7 +9,7 @@ use std::ops::RangeInclusive;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64};
use std::sync::{mpsc, Arc, RwLock};
use std::thread;
-use std::time::Duration;
+use std::time::{Duration, Instant};
use crate::freq::FrequencyKHZ;
use crate::socket;
@@ -18,6 +18,9 @@ use crate::Pid;
pub use budget::BudgetPolicy;
pub use trackers::{KernelDriver, PerfEstimator};
+const IDLE_CONSUMPTION_W: f64 = 7.;
+const UPDATE_INTERVAL_MS: u64 = 3;
+
pub enum Request {
NewTask(Pid, Arc<TaskInfo>),
RemoveTask(Pid),
@@ -85,8 +88,11 @@ pub struct EnergyService {
shared_cpu_current_frequencies: Arc<RwLock<Vec<FrequencyKHZ>>>,
rapl_offset: f64,
old_rapl: f64,
+ system_energy: f64,
bias: f64,
offset: f64,
+ graveyard: Vec<i32>,
+ last_measurement: Instant,
}
impl EnergyService {
@@ -112,22 +118,31 @@ impl EnergyService {
shared_cpu_current_frequencies,
rapl_offset: rapl::read_package_energy().unwrap(),
old_rapl: 0.,
+ system_energy: 0.,
bias: 1.,
offset: 0.,
+ graveyard: Vec::with_capacity(100),
+ last_measurement: Instant::now(),
}
}
pub fn run(mut self) {
thread::spawn(move || {
+ let mut i = 0;
loop {
+ i += 1;
// Process any incoming requests
self.handle_requests();
- // Update energy measurements
- self.update_measurements();
+ if i % 30 == 0 {
+ // Update energy measurements
+ self.update_measurements();
+
+ self.clear_graveyeard();
- // Calculate and update budgets
- self.update_budgets();
+ // Calculate and update budgets
+ self.update_budgets();
+ }
// Sleep for update interval
thread::sleep(self.update_interval);
@@ -150,11 +165,17 @@ impl EnergyService {
info.task_info.set_budget(old_budget);
return;
}
- self.estimator.start_trace(
- pid as u64,
- task_info.read_cpu(),
- task_info.is_running_on_e_core(),
- );
+ if self
+ .estimator
+ .start_trace(
+ pid as u64,
+ task_info.read_cpu(),
+ task_info.is_running_on_e_core(),
+ )
+ .is_err()
+ {
+ return;
+ }
let parent = (|| {
let process = procfs::process::Process::new(pid)?;
process.stat().map(|stat| stat.ppid)
@@ -179,10 +200,7 @@ impl EnergyService {
if procfs::process::Process::new(pid).is_ok() {
return;
}
-
- self.estimator.stop_trace(pid as u64);
- self.process_info.write().unwrap().remove(&pid);
- self.process_info.write().unwrap().remove(&pid);
+ self.graveyard.push(pid);
}
}
}
@@ -214,22 +232,28 @@ impl EnergyService {
}
}
}
+ let elapsed = self.last_measurement.elapsed();
+ self.last_measurement = Instant::now();
if let Some(init) = self.process_info.write().unwrap().get_mut(&1) {
let rapl = rapl::read_package_energy().unwrap() - self.rapl_offset;
let rapl_diff = rapl - self.old_rapl;
- let est_diff = init.tree_energy - old_energy;
- if est_diff < 0.1 {
- self.offset = (self.offset + (rapl_diff - est_diff)) * 0.5;
- }
+ let idle_consumption = elapsed.as_secs_f64() * IDLE_CONSUMPTION_W;
+ let est_diff = init.tree_energy - old_energy + idle_consumption;
self.old_rapl = rapl;
- init.tree_energy = init.tree_energy + self.offset;
- let offset_bias = (rapl / init.tree_energy).clamp(0.1, 2.);
- let diff_bias = (rapl_diff / est_diff).clamp(0.1, 2.);
- let current_bias = (offset_bias + diff_bias) * 0.5;
- self.bias = (self.bias * ((1. / 3.) * current_bias + (2. / 3.))).clamp(0.1, 20.);
+ // let offset_bias = (rapl / (init.tree_energy + idle_consumption)).clamp(0.1, 2.);
+ let current_bias = if init.tree_energy - old_energy > idle_consumption * 0.5 {
+ (rapl_diff / est_diff).clamp(0.1, 2.)
+ } else {
+ 1.
+ };
+ // let current_bias = (offset_bias + diff_bias) * 0.5;
+ let alpha: f64 = 10. * elapsed.as_secs_f64().recip();
+ self.bias = (self.bias * (alpha.recip() * current_bias + ((alpha - 1.) / alpha)))
+ .clamp(0.1, 5.);
+ self.system_energy += est_diff;
println!(
- "Energy estimation: {:.1} rapl: {:.1}, est diff: {:.1} rapl diff: {:.1}",
- init.tree_energy, rapl, est_diff, rapl_diff,
+ "Energy estimation: {:.1} rapl: {:.1}, est diff: {:.1} rapl diff: {:.1}, bias: {:.1}",
+ self.system_energy, rapl, est_diff, rapl_diff, self.bias,
);
}
}
@@ -248,6 +272,14 @@ impl EnergyService {
}
}
+ fn clear_graveyeard(&mut self) {
+ for pid in self.graveyard.drain(..) {
+ self.estimator.stop_trace(pid as u64);
+ self.active_processes.remove(&pid);
+ self.process_info.write().unwrap().remove(&pid);
+ }
+ }
+
// Accessor methods for BudgetPolicy
pub fn active_processes(&self) -> &BTreeSet<Pid> {
&self.active_processes
@@ -306,7 +338,7 @@ pub fn start_energy_service(
budget_policy,
process_info.clone(),
request_receiver,
- Duration::from_millis(50), // 50ms update interval
+ Duration::from_millis(UPDATE_INTERVAL_MS), // 50ms update interval
shared_cpu_frequency_ranges,
shared_policy_frequency_ranges,
shared_cpu_current_frequencies,
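In energy.rs, per-task estimates are now reconciled against RAPL by adding a fixed idle draw (IDLE_CONSUMPTION_W) to the estimated delta and folding the RAPL/estimate ratio into a smoothed multiplicative bias, replacing the previous additive offset. A small sketch of that update, assuming made-up deltas (the real code reads rapl_diff from RAPL and est_tree_diff from the estimator tree rooted at PID 1):

const IDLE_CONSUMPTION_W: f64 = 7.0;

// Sketch of the smoothed bias update from update_measurements.
fn update_bias(bias: f64, rapl_diff: f64, est_tree_diff: f64, elapsed_s: f64) -> f64 {
    let idle = elapsed_s * IDLE_CONSUMPTION_W;
    let est_diff = est_tree_diff + idle;
    // Only trust the ratio when estimated activity clearly exceeds idle noise.
    let current_bias = if est_tree_diff > idle * 0.5 {
        (rapl_diff / est_diff).clamp(0.1, 2.0)
    } else {
        1.0
    };
    // Exponential smoothing: a longer measurement interval gives a smaller
    // alpha and therefore a larger weight on the current ratio.
    let alpha: f64 = 10.0 * elapsed_s.recip();
    (bias * (alpha.recip() * current_bias + (alpha - 1.0) / alpha)).clamp(0.1, 5.0)
}

fn main() {
    let mut bias = 1.0;
    for _ in 0..5 {
        bias = update_bias(bias, 1.2, 0.8, 0.09); // ~90 ms between updates, invented values
        println!("bias = {bias:.3}");
    }
}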
diff --git a/src/energy/estimator.rs b/src/energy/estimator.rs
index fbab744..03034db 100644
--- a/src/energy/estimator.rs
+++ b/src/energy/estimator.rs
@@ -1,5 +1,5 @@
pub trait Estimator: Send + 'static {
- fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool);
+ fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) -> Result<(), ()>;
fn stop_trace(&mut self, pid: u64);
fn update_information(&mut self, pid: u64, cpu: i32, is_ecore: bool);
fn read_consumption(&mut self, pid: u64) -> Option<f64>;
diff --git a/src/energy/trackers/kernel.rs b/src/energy/trackers/kernel.rs
index f42bb16..a0178a4 100644
--- a/src/energy/trackers/kernel.rs
+++ b/src/energy/trackers/kernel.rs
@@ -30,8 +30,9 @@ const STOP_TRACE: Ioctl<Write, &u64> = unsafe { PERF_MON.write(0x81) };
const READ_POWER: Ioctl<WriteRead, &u64> = unsafe { PERF_MON.write_read(0x82) };
impl Estimator for KernelDriver {
- fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) {
+ fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) -> Result<(), ()> {
let _ = START_TRACE.ioctl(&mut self.file, &pid);
+ Ok(())
}
fn stop_trace(&mut self, pid: u64) {
diff --git a/src/energy/trackers/mock.rs b/src/energy/trackers/mock.rs
index d9ede50..cd08c34 100644
--- a/src/energy/trackers/mock.rs
+++ b/src/energy/trackers/mock.rs
@@ -3,7 +3,9 @@ use crate::energy::estimator::Estimator;
pub struct MockEstimator;
impl Estimator for MockEstimator {
- fn start_trace(&mut self, _pid: u64, _cpu: i32, _running_on_e_core: bool) {}
+ fn start_trace(&mut self, _pid: u64, _cpu: i32, _running_on_e_core: bool) -> Result<(), ()> {
+ Ok(())
+ }
fn stop_trace(&mut self, _pid: u64) {}
diff --git a/src/energy/trackers/perf.rs b/src/energy/trackers/perf.rs
index e59057d..38cefe9 100644
--- a/src/energy/trackers/perf.rs
+++ b/src/energy/trackers/perf.rs
@@ -22,8 +22,10 @@ pub struct PerfEstimator {
impl PerfEstimator {
pub fn new(shared_cpu_current_frequencies: Arc<RwLock<Vec<FrequencyKHZ>>>) -> Self {
- let model_p = crate::model::load_model_p();
- let model_e = crate::model::load_model_e();
+ // let model_p = crate::model::load_model("perf_pcore.pt");
+ let model_p = crate::model::load_model("perf.pt");
+ let model_e = crate::model::load_model("perf.pt");
+ // let model_e = crate::model::load_model("perf_ecore.pt");
Self {
registry: Default::default(),
model_p,
@@ -51,22 +53,18 @@ static EVENT_TYPES_P: &[Event] = &[
Event::Hardware(Hardware::CPU_CYCLES),
Event::Hardware(Hardware::INSTRUCTIONS),
Event::Hardware(Hardware::REF_CPU_CYCLES),
- Event::Software(Software::TASK_CLOCK),
];
-//TODO: use correct counter
static EVENT_TYPES_E: &[Event] = &[
- Event::Hardware(Hardware::BRANCH_INSTRUCTIONS),
Event::Hardware(Hardware::BRANCH_MISSES),
Event::Hardware(Hardware::CACHE_MISSES),
Event::Hardware(Hardware::CACHE_REFERENCES),
Event::Hardware(Hardware::CPU_CYCLES),
Event::Hardware(Hardware::INSTRUCTIONS),
Event::Hardware(Hardware::REF_CPU_CYCLES),
- Event::Software(Software::TASK_CLOCK),
];
impl Estimator for PerfEstimator {
- fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) {
+ fn start_trace(&mut self, pid: u64, cpu: i32, running_on_e_core: bool) -> Result<(), ()> {
let mut group = match Group::new_with_pid_and_cpu(pid as i32, -1) {
Ok(counters) => counters,
Err(e) => {
@@ -74,11 +72,12 @@ impl Estimator for PerfEstimator {
"Failed to create performance counter group for PID {}: {}",
pid, e
);
- return;
+ return Err(());
}
};
- let counters: Result<Vec<_>, _> = if running_on_e_core {
+ let counters: Result<Vec<_>, _> = if running_on_e_core || true {
+ // println!("starting e core counter");
EVENT_TYPES_E
} else {
EVENT_TYPES_P
@@ -100,28 +99,30 @@ impl Estimator for PerfEstimator {
"Failed to create performance counter group for PID {}: {}",
pid, e
);
- return;
+ return Err(());
}
};
if let Err(e) = group.enable() {
eprintln!("Failed to enable performance counters: {}", e);
- return;
+ return Err(());
}
if let Err(e) = group.reset() {
eprintln!("Failed to reset performance counters: {}", e);
- return;
+ return Err(());
}
+ let old_time = group.read().unwrap().time_running();
let counters = Counters {
counters,
group,
- old_time: 0,
+ old_time,
old_total_energy: 0.,
cpu,
running_on_e_core,
};
self.registry.insert(pid, counters);
+ Ok(())
}
fn stop_trace(&mut self, pid: u64) {
@@ -132,18 +133,21 @@ impl Estimator for PerfEstimator {
let mut core_type_changed = false;
if let Some(info) = self.registry.get_mut(&pid) {
info.cpu = cpu;
+ info.running_on_e_core = is_ecore;
core_type_changed = is_ecore != info.running_on_e_core;
} else {
eprintln!("Tried to update an unknown task")
}
if core_type_changed {
- self.stop_trace(pid);
- self.stop_trace(pid);
+ // println!("migrating task to {}", cpu);
+ // self.stop_trace(pid);
+ // self.start_trace(pid, cpu, is_ecore);
}
}
fn read_consumption(&mut self, pid: u64) -> Option<f64> {
let Some(counters) = self.registry.get_mut(&pid) else {
+ println!("did not find counters for {pid}");
return None;
};
@@ -154,21 +158,19 @@ impl Estimator for PerfEstimator {
return None;
}
};
- let num_counter = counters.counters.len();
- let task_clock = counts[&counters.counters[num_counter - 1]];
-
- if task_clock == 0 {
+ let time_running_ns = counts.time_running();
+ if time_running_ns - counters.old_time == 0 || counts.iter().next().unwrap().1 == &0 {
return None;
}
- let time_running_ns = counts.time_running();
let correction_factor = 10_000_000. / (time_running_ns - counters.old_time) as f64;
counters.old_time = time_running_ns;
let mut values = vec![
+ //if counters.running_on_e_core { 1. } else { 0. },
(self.shared_cpu_current_frequencies.read().unwrap()[counters.cpu as usize] / 1000)
as f64,
];
- for ty in counters.counters.iter().take(num_counter - 1) {
+ for ty in counters.counters.iter() {
let count: u64 = counts[&ty];
values.push((count as f64) * correction_factor);
}
@@ -181,6 +183,9 @@ impl Estimator for PerfEstimator {
.forward(Tensor::from_floats(&values.as_slice()[0..], &self.device));
let energy = result.into_scalar() as f64;
+ if counters.running_on_e_core {
+ // dbg!(energy);
+ }
counters.old_total_energy += energy / correction_factor;
counters.group.reset().unwrap();
Some(energy / correction_factor)
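The perf-based estimator now normalizes counter deltas by the group's time_running() instead of a dedicated task-clock counter: deltas are scaled to a nominal 10 ms window before the model sees them, and the predicted energy is scaled back to the real interval. A sketch of that scaling, where predict_energy_10ms is only a stand-in for the burn model's forward pass and is not part of the diff:

// Stand-in for the model: pretend energy grows with the normalized inputs.
fn predict_energy_10ms(normalized_inputs: &[f64]) -> f64 {
    normalized_inputs.iter().sum::<f64>() * 1e-9
}

fn consumption_for_interval(
    raw_counts: &[u64],
    time_running_ns: u64,
    old_time_ns: u64,
    freq_mhz: f64,
) -> Option<f64> {
    let delta_ns = time_running_ns.checked_sub(old_time_ns)?;
    if delta_ns == 0 {
        return None;
    }
    // Scale counts to a nominal 10 ms window so the model sees a fixed time base.
    let correction_factor = 10_000_000.0 / delta_ns as f64;
    let mut values = vec![freq_mhz];
    values.extend(raw_counts.iter().map(|&c| c as f64 * correction_factor));
    let energy_10ms = predict_energy_10ms(&values);
    // Undo the scaling to get the energy of the actual measured interval.
    Some(energy_10ms / correction_factor)
}

fn main() {
    // Invented counter deltas and timestamps, purely for illustration.
    let e = consumption_for_interval(&[1_000_000, 20_000], 25_000_000, 5_000_000, 2400.0);
    println!("estimated energy for interval: {:?} J", e);
}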
diff --git a/src/main.rs b/src/main.rs
index bb53fc5..9f3528c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -15,7 +15,9 @@ mod socket;
mod bpf;
use anyhow::Result;
+use burn::tensor::Tensor;
use clap::{Arg, ArgAction, Command};
+use model::load_model;
use scheduler::Scheduler;
use std::mem::MaybeUninit;
@@ -48,6 +50,21 @@ fn main() -> Result<()> {
)
.get_matches();
+ let device = Default::default();
+ let model = load_model("perf.pt");
+ let tensor = Tensor::from_floats(
+ [
+ 800., 90678., 54734., 153646., 20354478., 40948418.,
+ 89103105.,
+ //5200., 148947., 322426., 498965., 62340773., 144451046., 41976480.,
+ ],
+ &device,
+ );
+ let result = model.forward(tensor);
+ let energy: f32 = result.into_scalar();
+ println!("energy: {energy}");
+ // panic!();
+
let power_cap = *matches.get_one::<u64>("power_cap").unwrap_or(&u64::MAX);
let use_mocking = matches.get_flag("mock");
let benchmark = matches.get_one::<String>("benchmark");
diff --git a/src/model.rs b/src/model.rs
index fbb6e0b..6dbde1f 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -50,22 +50,10 @@ impl<B: Backend> Net<B> {
}
/// Load the p core model from the file in your source code (not in build.rs or script).
-pub fn load_model_p() -> Net<ArrayBackend> {
+pub fn load_model(path: &str) -> Net<ArrayBackend> {
let device = Default::default();
let record: NetRecord<ArrayBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
- .load("./perf.pt".into(), &device)
- .expect("Failed to decode state");
-
- Net::<ArrayBackend>::init(&device).load_record(record)
-}
-
-/// Load the e core model from the file in your source code (not in build.rs or script).
-pub fn load_model_e() -> Net<ArrayBackend> {
- //TODO: load e model
- println!("Falling back to p model");
- let device = Default::default();
- let record: NetRecord<ArrayBackend> = PyTorchFileRecorder::<FullPrecisionSettings>::default()
- .load("./perf.pt".into(), &device)
+ .load(path.into(), &device)
.expect("Failed to decode state");
Net::<ArrayBackend>::init(&device).load_record(record)