From 29bffc6f6c794fee886904ad3960c4cb770deb11 Mon Sep 17 00:00:00 2001 From: Dennis Kobert Date: Sun, 12 Jan 2020 05:18:56 +0100 Subject: Fix Bugs --- src/solvers/gpu/output.rs | 83 +++++++++++++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 28 deletions(-) (limited to 'src/solvers/gpu/output.rs') diff --git a/src/solvers/gpu/output.rs b/src/solvers/gpu/output.rs index 58a4aa5..a716340 100644 --- a/src/solvers/gpu/output.rs +++ b/src/solvers/gpu/output.rs @@ -1,13 +1,12 @@ -use super::Message; -use std::collections::{HashSet, HashMap}; +use super::{Message, ResultMessage}; +use std::collections::{HashMap, HashSet}; use std::sync::mpsc::{channel, Receiver, Sender}; use std::thread::JoinHandle; struct InBuffer { receiver: Receiver, - row_requests: HashMap>, - results_requests: HashMap>, - + row_requests: HashMap>>, + results_requests: HashMap, } impl InBuffer { @@ -18,23 +17,51 @@ impl InBuffer { results_requests: HashMap::new(), } } - fn read(&mut self) -> Option { + fn read(&mut self) -> Option> { loop { - match self.receiver.recv() { - Message::OutputMessage((id, ResultMessage)) => { - if Some(result) = self.results_requests.get(id) { - Some(RowResult::new() + match self + .receiver + .recv() + .expect("Channel to Output Daemon broke") + { + Message::ResultMessage(results) => { + if let Some(result_walls) = self.row_requests.get(&results.id) { + return Some(Self::calc_results(results.valid_walls(), result_walls)); + } else { + self.results_requests.insert(results.id, results); + } + } + Message::OutputMessage((id, output)) => { + if let Some(results) = self.results_requests.get(&id) { + return Some(Self::calc_results(results.valid_walls(), output.as_ref())); + } else { + self.row_requests.insert(id, output); } - else { - self.row_requests.insert(id, walls);} + } + Message::Terminate => { + return None; + } + _ => { + println!("Invalid MessageType"); + } } } } - + fn calc_results(res_req: &[Vec], row_req: &[Vec]) -> Vec { + let out = Vec::new(); + for (rows, perms) in row_req.iter().zip(res_req.iter()) { + for p in perms { + let new = rows.clone(); + new.push(*p); + out.push(RowResult::new(new)); + } + } + out + } } #[derive(PartialEq, Eq, Hash)] -struct RowResult { +pub struct RowResult { rows: Vec, } @@ -50,26 +77,25 @@ impl RowResult { pub struct Output { input: InBuffer, - receiver: Receiver, permutations: Vec>, permutations_mask: Vec, results: HashSet, } impl Output { - fn launch_sevice( + pub fn launch_sevice( permutations: &[Vec], - permutation_masks: &[u64], - ) -> (Sender, JoinHandle) { + permutations_mask: &[u64], + ) -> (Sender, JoinHandle<()>) { let (sender, receiver) = channel(); let input = InBuffer::new(receiver); let output = Self { input, permutations: permutations.into(), - permutation_masks: permutation_masks.into(), - HashSet::new(), - } + permutations_mask: permutations_mask.into(), + results: HashSet::new(), + }; ( sender, std::thread::Builder::new() @@ -83,13 +109,14 @@ impl Output { fn run(mut self) { loop { - match self.receiver.recv() { - Message::OutputMessage((id, ResultMessage)) => { - if Some(result) = self.results_requests.get(id) { - Some(RowResult::new() - } - else { - self.row_requests.insert(id, walls);} + if let Some(walls) = self.input.read() { + for wall in walls { + self.results.insert(wall); + } + } else { + for wall in self.results { + wall.output() + } } } } -- cgit v1.2.3-54-g00ecf