diff options
Diffstat (limited to 'src/solvers/gpu/manager.rs')
-rw-r--r-- | src/solvers/gpu/manager.rs | 129 |
1 files changed, 70 insertions, 59 deletions
diff --git a/src/solvers/gpu/manager.rs b/src/solvers/gpu/manager.rs index 42af314..6eaaaaa 100644 --- a/src/solvers/gpu/manager.rs +++ b/src/solvers/gpu/manager.rs @@ -1,6 +1,7 @@ -use super::{CheckRequest, Message, RowResult}; +use super::{CheckRequest, GpuError, Message}; use std::sync::mpsc::{channel, Receiver, Sender}; use std::thread::JoinHandle; +type Buffer<'a> = (&'a [u64], &'a [Vec<u32>]); #[derive(Debug)] struct RequestBuffer { @@ -17,18 +18,19 @@ impl RequestBuffer { pointer: 0, } } - pub fn read(&mut self, request: CheckRequest) -> Option<(&[u64], &[Vec<u32>])> { + pub fn read(&mut self, request: CheckRequest) -> Result<Option<Buffer>, GpuError> { self.mask_buff[self.pointer] = request.bitmask; self.row_buff[self.pointer] = request.rows; self.pointer += 1; if self.pointer == self.mask_buff.len() { self.pointer = 0; - return Some((self.mask_buff.as_ref(), self.row_buff.as_ref())); + return Ok(Some((self.mask_buff.as_ref(), self.row_buff.as_ref()))); } - None + Ok(None) } - fn flush(&mut self) { - while self.read(CheckRequest::new(vec![], 0, 0)).is_none() {} + fn flush(&mut self) -> Result<(), GpuError> { + while self.read(CheckRequest::new(vec![], 0, 0))?.is_none() {} + Ok(()) } } @@ -50,13 +52,13 @@ impl OclManager { // Workgroup size, set to 0 for max mut wg_size: usize, result_output: Sender<Message>, - ) -> (Sender<Message>, JoinHandle<()>) { + ) -> Result<(Sender<Message>, JoinHandle<()>), GpuError> { let (h, w) = crate::solvers::wall_stats(n); let src = include_str!("check.cl"); let platform = ocl::Platform::default(); - let device = ocl::Device::first(platform).expect("failed to create opencl device"); - let max_wg_size = device.max_wg_size().expect("failed to query max_wg_size"); + let device = ocl::Device::first(platform)?; + let max_wg_size = device.max_wg_size()?; if wg_size == 0 { wg_size = max_wg_size; } else if wg_size > max_wg_size { @@ -65,7 +67,7 @@ impl OclManager { let (output_sender, output_handle) = super::output::Output::launch_sevice(permutations, permutations_mask, result_output); - let (host_sender, host_handle) = super::host::Host::launch_sevice( + if let Ok((host_sender, host_handle)) = super::host::Host::launch_sevice( permutations_mask, n, h, @@ -73,76 +75,85 @@ impl OclManager { wg_size, src, output_sender.clone(), - ) - .unwrap(); + ) { + let (sender, receiver) = channel(); - let (sender, receiver) = channel(); + println!("wg {}", wg_size); + let mut buffers = Vec::with_capacity((n - h + 1) as usize); + for _ in 0..=(n - h) { + buffers.push(RequestBuffer::new(wg_size as usize)); + } - println!("wg {}", wg_size); - let mut buffers = Vec::with_capacity((n - h + 1) as usize); - for _ in 0..=(n - h) { - buffers.push(RequestBuffer::new(wg_size as usize)); + let manager = Self { + job_id: 0, + host_sender, + output_sender, + receiver, + buffers, + output_handle, + host_handle: Some(host_handle), + }; + Ok(( + sender, + std::thread::Builder::new() + .name("GPU Manager Deamon".into()) + .spawn(move || { + if let Err(err) = manager.run() { + println!("{}", err); + } + }) + .unwrap(), + )) + } else { + Err(GpuError::from( + "Failed to launch the opnecl thread".to_string(), + )) } - - let manager = Self { - job_id: 0, - host_sender, - output_sender, - receiver, - buffers, - output_handle, - host_handle: Some(host_handle), - }; - ( - sender, - std::thread::Builder::new() - .name("GPU Manager Deamon".into()) - .spawn(move || { - manager.run(); - }) - .unwrap(), - ) } - fn run(mut self) { + fn run(mut self) -> Result<(), GpuError> { loop { - match self.receiver.recv().expect("Channel to GPU Manager broke") { + match self.receiver.recv()? { Message::CheckRequest(request) => { let queue = request.queue; //println!("num: {:?} bit {:b}", request.rows, request.bitmask); - if let Some(buffer) = self.buffers[queue as usize].read(request) { - self.host_sender - .send(Message::HostMessage((self.job_id, queue, buffer.0.into()))) - .unwrap(); + if let Some(buffer) = self.buffers[queue as usize].read(request)? { + self.host_sender.send(Message::HostMessage(( + self.job_id, + queue, + buffer.0.into(), + )))?; self.output_sender - .send(Message::OutputMessage((self.job_id, buffer.1.into()))) - .unwrap(); + .send(Message::OutputMessage((self.job_id, buffer.1.into())))?; self.job_id += 1; } } Message::CpuDone => { for (i, b) in self.buffers.iter_mut().enumerate() { - b.flush(); - self.host_sender - .send(Message::HostMessage(( - self.job_id, - i as u32, - b.mask_buff.clone(), - ))) - .unwrap(); + b.flush()?; + self.host_sender.send(Message::HostMessage(( + self.job_id, + i as u32, + b.mask_buff.clone(), + )))?; self.output_sender - .send(Message::OutputMessage((self.job_id, b.row_buff.clone()))) - .unwrap(); + .send(Message::OutputMessage((self.job_id, b.row_buff.clone())))?; self.job_id += 1; } println!("flushing buffers"); - self.host_sender.send(Message::CpuDone); - self.host_handle.take().unwrap().join(); + self.host_sender.send(Message::CpuDone)?; + self.host_handle + .take() + .unwrap() + .join() + .expect("failed to join host thread"); } Message::Terminate => { - self.output_sender.send(Message::Terminate); - self.output_handle.join(); - return; + self.output_sender.send(Message::Terminate)?; + self.output_handle + .join() + .expect("failed to join ouput thread"); + return Ok(()); } _ => println!("Invalid MessageType"), } |