diff options
Diffstat (limited to 'src/solvers/gpu')
-rw-r--r-- | src/solvers/gpu/host.rs | 12 | ||||
-rw-r--r-- | src/solvers/gpu/manager.rs | 48 | ||||
-rw-r--r-- | src/solvers/gpu/mod.rs | 17 | ||||
-rw-r--r-- | src/solvers/gpu/output.rs | 83 |
4 files changed, 97 insertions, 63 deletions
diff --git a/src/solvers/gpu/host.rs b/src/solvers/gpu/host.rs index 6b79078..c53c42c 100644 --- a/src/solvers/gpu/host.rs +++ b/src/solvers/gpu/host.rs @@ -1,3 +1,4 @@ +use super::{HostMessage, Message, ResultMessage}; use ocl::{flags, Buffer, Context, Device, Kernel, Platform, Program, Queue}; use std::sync::mpsc::{Receiver, Sender}; @@ -17,7 +18,7 @@ pub struct Host { /// Workgroup size, set to 0 for max wg_size: usize, permutations: Buffer<u64>, - rec_queues: Vec<RequestBuffer>, + receiver: Receiver<Message>, walls: Vec<Vec<u32>>, } @@ -29,7 +30,7 @@ impl Host { w: u32, mut wg_size: usize, src: &str, - ) -> ocl::Result<Vec<Sender<Job>>> { + ) -> ocl::Result<(Sender<Message>, std::thread::JoinHandle<()>)> { let platform = ocl::Platform::default(); let device = ocl::Device::first(platform)?; let context = ocl::Context::builder() @@ -49,19 +50,12 @@ impl Host { .len(permutation_masks.len()) .build()?; - let mut senders = Vec::with_capacity((n - h + 1) as usize); - let mut receivers = Vec::with_capacity((n - h + 1) as usize); let max_wg_size = device.max_wg_size()?; if wg_size == 0 { wg_size = max_wg_size; } else if wg_size > max_wg_size { return Err(ocl::Error::from("invalid workgroup size")); } - for _ in 0..=(n - h) { - let (sx, rx) = std::sync::mpsc::channel(); - senders.push(sx); - receivers.push(RequestBuffer::new(wg_size, rx)); - } let solver = Self { platform, diff --git a/src/solvers/gpu/manager.rs b/src/solvers/gpu/manager.rs index 1dd6a4d..b3d88b8 100644 --- a/src/solvers/gpu/manager.rs +++ b/src/solvers/gpu/manager.rs @@ -1,10 +1,11 @@ -use std::sync::mpsc::{Receiver, Sender, channel}; +use super::{CheckRequest, Message}; +use std::sync::mpsc::{channel, Receiver, Sender}; use std::thread::JoinHandle; -use super::*; #[derive(Debug)] struct RequestBuffer { mask_buff: Vec<u64>, + row_buff: Vec<Vec<u32>>, pointer: usize, } @@ -12,15 +13,17 @@ impl RequestBuffer { pub fn new(size: usize) -> Self { RequestBuffer { mask_buff: vec![0; size], + row_buff: vec![Vec::new(); size], pointer: 0, } } - pub fn read(&mut self, request: CheckRequest) -> Option<&[u64]> { + pub fn read(&mut self, request: CheckRequest) -> Option<(&[u64], &[Vec<u32>])> { self.mask_buff[self.pointer] = request.bitmask; + self.row_buff[self.pointer] = request.rows; self.pointer += 1; if self.pointer == self.mask_buff.len() { self.pointer = 0; - return Some(self.mask_buff.as_ref()); + return Some((self.mask_buff.as_ref(), self.row_buff.as_ref())); } None } @@ -30,28 +33,28 @@ pub struct OclManager { job_id: u64, host_sender: Sender<Message>, output_sender: Sender<Message>, - reciever: Receiver<Message>, + receiver: Receiver<Message>, buffers: Vec<RequestBuffer>, - output_handle: JoinHandle<String>, - host_handle: JoinHandle<String>, + output_handle: JoinHandle<()>, + host_handle: JoinHandle<()>, } impl OclManager { pub fn launch_sevice( - permutations: &[&[u32]], + permutations: &[Vec<u32>], permutations_mask: &[u64], n: u32, // Workgroup size, set to 0 for max wg_size: u32, - ) -> (Sender<Message>, JoinHandle<String>) { + ) -> (Sender<Message>, JoinHandle<()>) { let (h, w) = crate::solvers::wall_stats(n); let src = include_str!("check.cl"); let (output_sender, output_handle) = - super::output::Output::launch_sevice(permutations, permutations_mask, n, h, w); + super::output::Output::launch_sevice(permutations, permutations_mask); let (host_sender, host_handle) = super::host::Host::launch_sevice(permutations_mask, n, h, w, wg_size as usize, src); - let (receiver, sender) = channel(); + let (sender, receiver) = channel(); let mut buffers = Vec::with_capacity((n - h + 1) as usize); for _ in 0..=(n - h) { @@ -59,27 +62,28 @@ impl OclManager { } let manager = Self { - 0, + job_id: 0, host_sender, output_sender, receiver, buffers, output_handle, host_handle, - } - (sender, - std::thread::Builder::new() - .name("GPU Manager Deamon".into()) - .spawn(move || { - manager.run(); - }) - .unwrap()) - + }; + ( + sender, + std::thread::Builder::new() + .name("GPU Manager Deamon".into()) + .spawn(move || { + manager.run(); + }) + .unwrap(), + ) } fn run(mut self) { loop { - match self.reciever.recv().expect("Channel to GPU Manager broke") { + match self.receiver.recv().expect("Channel to GPU Manager broke") { Message::CheckRequest(request) => { if let Some(buffer) = self.buffers[request.queue as usize].read(request) { self.host_sender diff --git a/src/solvers/gpu/mod.rs b/src/solvers/gpu/mod.rs index f9ab711..f147f31 100644 --- a/src/solvers/gpu/mod.rs +++ b/src/solvers/gpu/mod.rs @@ -11,6 +11,7 @@ pub enum Message { CheckRequest(CheckRequest), HostMessage(MaskMessage), OutputMessage(RowMessage), + ResultMessage(ResultMessage), Terminate, } @@ -18,14 +19,22 @@ pub struct ResultMessage { data: Vec<u64>, offset: usize, size: usize, + wg_size: usize, + id: u64, } impl ResultMessage { - fn new(data: Vec<u64>, offset: usize, size: usize) -> Self { - Self { data, offset, size } + fn new(data: Vec<u64>, offset: usize, size: usize, wg_size: usize, id: u64) -> Self { + Self { + data, + offset, + size, + wg_size, + id, + } } - fn valid_walls(&self, wg_size: usize) -> &[Vec<u32>] { - let mut result = vec![Vec::new(); wg_size]; + fn valid_walls(&self) -> &[Vec<u32>] { + let mut result = vec![Vec::new(); self.wg_size]; for (j, r) in self.data.iter().enumerate() { for b in 0..64 { if r & (1 << b) != 0 { diff --git a/src/solvers/gpu/output.rs b/src/solvers/gpu/output.rs index 58a4aa5..a716340 100644 --- a/src/solvers/gpu/output.rs +++ b/src/solvers/gpu/output.rs @@ -1,13 +1,12 @@ -use super::Message; -use std::collections::{HashSet, HashMap}; +use super::{Message, ResultMessage}; +use std::collections::{HashMap, HashSet}; use std::sync::mpsc::{channel, Receiver, Sender}; use std::thread::JoinHandle; struct InBuffer { receiver: Receiver<Message>, - row_requests: HashMap<u64, Vec<u32>>, - results_requests: HashMap<u64, Vec<u64>>, - + row_requests: HashMap<u64, Vec<Vec<u32>>>, + results_requests: HashMap<u64, ResultMessage>, } impl InBuffer { @@ -18,23 +17,51 @@ impl InBuffer { results_requests: HashMap::new(), } } - fn read(&mut self) -> Option<Result> { + fn read(&mut self) -> Option<Vec<RowResult>> { loop { - match self.receiver.recv() { - Message::OutputMessage((id, ResultMessage)) => { - if Some(result) = self.results_requests.get(id) { - Some(RowResult::new() + match self + .receiver + .recv() + .expect("Channel to Output Daemon broke") + { + Message::ResultMessage(results) => { + if let Some(result_walls) = self.row_requests.get(&results.id) { + return Some(Self::calc_results(results.valid_walls(), result_walls)); + } else { + self.results_requests.insert(results.id, results); + } + } + Message::OutputMessage((id, output)) => { + if let Some(results) = self.results_requests.get(&id) { + return Some(Self::calc_results(results.valid_walls(), output.as_ref())); + } else { + self.row_requests.insert(id, output); } - else { - self.row_requests.insert(id, walls);} + } + Message::Terminate => { + return None; + } + _ => { + println!("Invalid MessageType"); + } } } } - + fn calc_results(res_req: &[Vec<u32>], row_req: &[Vec<u32>]) -> Vec<RowResult> { + let out = Vec::new(); + for (rows, perms) in row_req.iter().zip(res_req.iter()) { + for p in perms { + let new = rows.clone(); + new.push(*p); + out.push(RowResult::new(new)); + } + } + out + } } #[derive(PartialEq, Eq, Hash)] -struct RowResult { +pub struct RowResult { rows: Vec<u32>, } @@ -50,26 +77,25 @@ impl RowResult { pub struct Output { input: InBuffer, - receiver: Receiver<Message>, permutations: Vec<Vec<u32>>, permutations_mask: Vec<u64>, results: HashSet<RowResult>, } impl Output { - fn launch_sevice( + pub fn launch_sevice( permutations: &[Vec<u32>], - permutation_masks: &[u64], - ) -> (Sender<Message>, JoinHandle<String>) { + permutations_mask: &[u64], + ) -> (Sender<Message>, JoinHandle<()>) { let (sender, receiver) = channel(); let input = InBuffer::new(receiver); let output = Self { input, permutations: permutations.into(), - permutation_masks: permutation_masks.into(), - HashSet::new(), - } + permutations_mask: permutations_mask.into(), + results: HashSet::new(), + }; ( sender, std::thread::Builder::new() @@ -83,13 +109,14 @@ impl Output { fn run(mut self) { loop { - match self.receiver.recv() { - Message::OutputMessage((id, ResultMessage)) => { - if Some(result) = self.results_requests.get(id) { - Some(RowResult::new() - } - else { - self.row_requests.insert(id, walls);} + if let Some(walls) = self.input.read() { + for wall in walls { + self.results.insert(wall); + } + } else { + for wall in self.results { + wall.output() + } } } } |