use super::{Message, ResultMessage, RowResult}; use std::collections::{HashMap, HashSet}; use std::sync::mpsc::{channel, Receiver, Sender}; use std::thread::JoinHandle; struct InBuffer { receiver: Receiver, row_requests: HashMap>>, results_requests: HashMap, banned_requests: HashSet, } impl InBuffer { fn new(receiver: Receiver) -> Self { Self { receiver, row_requests: HashMap::new(), results_requests: HashMap::new(), banned_requests: HashSet::new(), } } fn read(&mut self) -> Option> { loop { //println!("{:?}", self.receiver.recv().unwrap()); //continue; match self .receiver .recv() .expect("Channel to Output Daemon broke") { Message::ResultMessage(results) => { if results.data.iter().any(|x| *x != 0) { println!("Horay results!"); if let Some(result_walls) = self.row_requests.get(&results.id) { return Some(Self::calc_results( results.valid_walls().as_ref(), result_walls, )); } else { self.results_requests.insert(results.id, results); } } else { if self.row_requests.remove(&results.id).is_none() { self.banned_requests.insert(results.id); } } } Message::OutputMessage((id, output)) => { if self.banned_requests.remove(&id) { continue; } if let Some(results) = self.results_requests.get(&id) { return Some(Self::calc_results( results.valid_walls().as_ref(), output.as_ref(), )); } else { self.row_requests.insert(id, output); } } Message::CpuDone => { return None; } Message::Terminate => { return None; } _ => { println!("Invalid MessageType"); } } } } fn calc_results(res_req: &[Vec], row_req: &[Vec]) -> Vec { let mut out = Vec::new(); for (rows, perms) in row_req.iter().zip(res_req.iter()) { for p in perms { let mut new = rows.clone(); new.push(*p); out.push(RowResult::new(new)); } } out } } pub struct Output { input: InBuffer, permutations: Vec>, permutations_mask: Vec, results: HashSet, result_sender: Sender, } impl Output { pub fn launch_sevice( permutations: &[Vec], permutations_mask: &[u64], result_sender: Sender, ) -> (Sender, JoinHandle<()>) { let (sender, receiver) = channel(); let input = InBuffer::new(receiver); let output = Self { input, permutations: permutations.into(), permutations_mask: permutations_mask.into(), results: HashSet::new(), result_sender, }; ( sender, std::thread::Builder::new() .name("GPU Output Deamon".into()) .spawn(move || { output.run(); }) .unwrap(), ) } fn run(mut self) { loop { if let Some(walls) = self.input.read() { for wall in walls { if !self.results.contains(&wall) { wall.output(); self.result_sender .send(Message::RowResult(wall.clone())) .or_else(|_| Err(println!("Failed to transmit result back"))); } self.results.insert(wall); } } else { for wall in self.results { wall.output() } self.result_sender.send(Message::GpuDone).unwrap(); // wait for second exit signal self.input.read(); return; } } } }