summaryrefslogtreecommitdiff
path: root/src/solvers/gpu/manager.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/solvers/gpu/manager.rs')
-rw-r--r--src/solvers/gpu/manager.rs129
1 files changed, 70 insertions, 59 deletions
diff --git a/src/solvers/gpu/manager.rs b/src/solvers/gpu/manager.rs
index 42af314..6eaaaaa 100644
--- a/src/solvers/gpu/manager.rs
+++ b/src/solvers/gpu/manager.rs
@@ -1,6 +1,7 @@
-use super::{CheckRequest, Message, RowResult};
+use super::{CheckRequest, GpuError, Message};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::JoinHandle;
+type Buffer<'a> = (&'a [u64], &'a [Vec<u32>]);
#[derive(Debug)]
struct RequestBuffer {
@@ -17,18 +18,19 @@ impl RequestBuffer {
pointer: 0,
}
}
- pub fn read(&mut self, request: CheckRequest) -> Option<(&[u64], &[Vec<u32>])> {
+ pub fn read(&mut self, request: CheckRequest) -> Result<Option<Buffer>, GpuError> {
self.mask_buff[self.pointer] = request.bitmask;
self.row_buff[self.pointer] = request.rows;
self.pointer += 1;
if self.pointer == self.mask_buff.len() {
self.pointer = 0;
- return Some((self.mask_buff.as_ref(), self.row_buff.as_ref()));
+ return Ok(Some((self.mask_buff.as_ref(), self.row_buff.as_ref())));
}
- None
+ Ok(None)
}
- fn flush(&mut self) {
- while self.read(CheckRequest::new(vec![], 0, 0)).is_none() {}
+ fn flush(&mut self) -> Result<(), GpuError> {
+ while self.read(CheckRequest::new(vec![], 0, 0))?.is_none() {}
+ Ok(())
}
}
@@ -50,13 +52,13 @@ impl OclManager {
// Workgroup size, set to 0 for max
mut wg_size: usize,
result_output: Sender<Message>,
- ) -> (Sender<Message>, JoinHandle<()>) {
+ ) -> Result<(Sender<Message>, JoinHandle<()>), GpuError> {
let (h, w) = crate::solvers::wall_stats(n);
let src = include_str!("check.cl");
let platform = ocl::Platform::default();
- let device = ocl::Device::first(platform).expect("failed to create opencl device");
- let max_wg_size = device.max_wg_size().expect("failed to query max_wg_size");
+ let device = ocl::Device::first(platform)?;
+ let max_wg_size = device.max_wg_size()?;
if wg_size == 0 {
wg_size = max_wg_size;
} else if wg_size > max_wg_size {
@@ -65,7 +67,7 @@ impl OclManager {
let (output_sender, output_handle) =
super::output::Output::launch_sevice(permutations, permutations_mask, result_output);
- let (host_sender, host_handle) = super::host::Host::launch_sevice(
+ if let Ok((host_sender, host_handle)) = super::host::Host::launch_sevice(
permutations_mask,
n,
h,
@@ -73,76 +75,85 @@ impl OclManager {
wg_size,
src,
output_sender.clone(),
- )
- .unwrap();
+ ) {
+ let (sender, receiver) = channel();
- let (sender, receiver) = channel();
+ println!("wg {}", wg_size);
+ let mut buffers = Vec::with_capacity((n - h + 1) as usize);
+ for _ in 0..=(n - h) {
+ buffers.push(RequestBuffer::new(wg_size as usize));
+ }
- println!("wg {}", wg_size);
- let mut buffers = Vec::with_capacity((n - h + 1) as usize);
- for _ in 0..=(n - h) {
- buffers.push(RequestBuffer::new(wg_size as usize));
+ let manager = Self {
+ job_id: 0,
+ host_sender,
+ output_sender,
+ receiver,
+ buffers,
+ output_handle,
+ host_handle: Some(host_handle),
+ };
+ Ok((
+ sender,
+ std::thread::Builder::new()
+ .name("GPU Manager Deamon".into())
+ .spawn(move || {
+ if let Err(err) = manager.run() {
+ println!("{}", err);
+ }
+ })
+ .unwrap(),
+ ))
+ } else {
+ Err(GpuError::from(
+ "Failed to launch the opnecl thread".to_string(),
+ ))
}
-
- let manager = Self {
- job_id: 0,
- host_sender,
- output_sender,
- receiver,
- buffers,
- output_handle,
- host_handle: Some(host_handle),
- };
- (
- sender,
- std::thread::Builder::new()
- .name("GPU Manager Deamon".into())
- .spawn(move || {
- manager.run();
- })
- .unwrap(),
- )
}
- fn run(mut self) {
+ fn run(mut self) -> Result<(), GpuError> {
loop {
- match self.receiver.recv().expect("Channel to GPU Manager broke") {
+ match self.receiver.recv()? {
Message::CheckRequest(request) => {
let queue = request.queue;
//println!("num: {:?} bit {:b}", request.rows, request.bitmask);
- if let Some(buffer) = self.buffers[queue as usize].read(request) {
- self.host_sender
- .send(Message::HostMessage((self.job_id, queue, buffer.0.into())))
- .unwrap();
+ if let Some(buffer) = self.buffers[queue as usize].read(request)? {
+ self.host_sender.send(Message::HostMessage((
+ self.job_id,
+ queue,
+ buffer.0.into(),
+ )))?;
self.output_sender
- .send(Message::OutputMessage((self.job_id, buffer.1.into())))
- .unwrap();
+ .send(Message::OutputMessage((self.job_id, buffer.1.into())))?;
self.job_id += 1;
}
}
Message::CpuDone => {
for (i, b) in self.buffers.iter_mut().enumerate() {
- b.flush();
- self.host_sender
- .send(Message::HostMessage((
- self.job_id,
- i as u32,
- b.mask_buff.clone(),
- )))
- .unwrap();
+ b.flush()?;
+ self.host_sender.send(Message::HostMessage((
+ self.job_id,
+ i as u32,
+ b.mask_buff.clone(),
+ )))?;
self.output_sender
- .send(Message::OutputMessage((self.job_id, b.row_buff.clone())))
- .unwrap();
+ .send(Message::OutputMessage((self.job_id, b.row_buff.clone())))?;
self.job_id += 1;
}
println!("flushing buffers");
- self.host_sender.send(Message::CpuDone);
- self.host_handle.take().unwrap().join();
+ self.host_sender.send(Message::CpuDone)?;
+ self.host_handle
+ .take()
+ .unwrap()
+ .join()
+ .expect("failed to join host thread");
}
Message::Terminate => {
- self.output_sender.send(Message::Terminate);
- self.output_handle.join();
- return;
+ self.output_sender.send(Message::Terminate)?;
+ self.output_handle
+ .join()
+ .expect("failed to join ouput thread");
+ return Ok(());
}
_ => println!("Invalid MessageType"),
}