summaryrefslogtreecommitdiff
path: root/src/solvers/gpu
diff options
context:
space:
mode:
authorDennis Kobert <dennis@kobert.dev>2020-01-12 09:51:22 +0100
committerDennis Kobert <dennis@kobert.dev>2020-01-12 09:51:22 +0100
commit8708a172ebe59d3189b8b9d756abd9da8dc509a3 (patch)
treed2fb58525e4773a509016850cad55c27540de736 /src/solvers/gpu
parent3a0d646ade02a6ca006a0d8cf6c0f60a1ece8272 (diff)
Allow to use teh iterator for results
Diffstat (limited to 'src/solvers/gpu')
-rw-r--r--src/solvers/gpu/host.rs7
-rw-r--r--src/solvers/gpu/manager.rs17
-rw-r--r--src/solvers/gpu/mod.rs22
-rw-r--r--src/solvers/gpu/output.rs34
4 files changed, 53 insertions, 27 deletions
diff --git a/src/solvers/gpu/host.rs b/src/solvers/gpu/host.rs
index e354263..d67138e 100644
--- a/src/solvers/gpu/host.rs
+++ b/src/solvers/gpu/host.rs
@@ -123,8 +123,11 @@ impl Host {
println!("finished gpu setup");
loop {
match self.receiver.recv().expect("Channel to Host broke") {
+ Message::CpuDone => {
+ self.output_sender.send(Message::CpuDone);
+ return;
+ }
Message::Terminate => {
- self.output_sender.send(Message::Terminate);
return;
}
Message::HostMessage((id, i, buffer)) => {
@@ -184,7 +187,7 @@ impl Host {
)))
.unwrap();
}
- _ => println!("Invalid MessageType"),
+ m => println!("Invalid MessageType {:?} recived by host", m),
}
}
}
diff --git a/src/solvers/gpu/manager.rs b/src/solvers/gpu/manager.rs
index a2253aa..def1f35 100644
--- a/src/solvers/gpu/manager.rs
+++ b/src/solvers/gpu/manager.rs
@@ -1,4 +1,4 @@
-use super::{CheckRequest, Message};
+use super::{CheckRequest, Message, RowResult};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::JoinHandle;
@@ -36,7 +36,7 @@ pub struct OclManager {
receiver: Receiver<Message>,
buffers: Vec<RequestBuffer>,
output_handle: JoinHandle<()>,
- host_handle: JoinHandle<()>,
+ host_handle: Option<JoinHandle<()>>,
}
impl OclManager {
@@ -46,6 +46,7 @@ impl OclManager {
n: u32,
// Workgroup size, set to 0 for max
mut wg_size: usize,
+ result_output: Sender<Message>,
) -> (Sender<Message>, JoinHandle<()>) {
let (h, w) = crate::solvers::wall_stats(n);
let src = include_str!("check.cl");
@@ -60,7 +61,7 @@ impl OclManager {
}
let (output_sender, output_handle) =
- super::output::Output::launch_sevice(permutations, permutations_mask);
+ super::output::Output::launch_sevice(permutations, permutations_mask, result_output);
let (host_sender, host_handle) = super::host::Host::launch_sevice(
permutations_mask,
n,
@@ -87,7 +88,7 @@ impl OclManager {
receiver,
buffers,
output_handle,
- host_handle,
+ host_handle: Some(host_handle),
};
(
sender,
@@ -115,10 +116,12 @@ impl OclManager {
self.job_id += 1;
}
}
- Message::Terminate => {
+ Message::CpuDone => {
//TODO panic!("flush buffers");
- self.host_sender.send(Message::Terminate);
- self.host_handle.join();
+ self.host_sender.send(Message::CpuDone);
+ self.host_handle.take().unwrap().join();
+ }
+ Message::Terminate => {
self.output_sender.send(Message::Terminate);
self.output_handle.join();
return;
diff --git a/src/solvers/gpu/mod.rs b/src/solvers/gpu/mod.rs
index 2c7f69d..e89f033 100644
--- a/src/solvers/gpu/mod.rs
+++ b/src/solvers/gpu/mod.rs
@@ -7,14 +7,19 @@ pub use manager::*;
type MaskMessage = (u64, u32, Vec<u64>);
type RowMessage = (u64, Vec<Vec<u32>>);
+#[derive(Debug)]
pub enum Message {
CheckRequest(CheckRequest),
HostMessage(MaskMessage),
OutputMessage(RowMessage),
ResultMessage(ResultMessage),
+ RowResult(RowResult),
Terminate,
+ CpuDone,
+ GpuDone,
}
+#[derive(Debug)]
pub struct ResultMessage {
data: Vec<u64>,
offset: usize,
@@ -48,6 +53,7 @@ impl ResultMessage {
}
}
+#[derive(Debug)]
pub struct CheckRequest {
rows: Vec<u32>,
bitmask: u64,
@@ -63,3 +69,19 @@ impl CheckRequest {
}
}
}
+
+#[derive(Clone, PartialEq, Eq, Hash, Debug)]
+pub struct RowResult {
+ rows: Vec<u32>,
+}
+
+impl RowResult {
+ fn new(mut rows: Vec<u32>) -> Self {
+ rows.push(0);
+ rows.sort();
+ Self { rows }
+ }
+ fn output(&self) {
+ println!("{:?}", self.rows);
+ }
+}
diff --git a/src/solvers/gpu/output.rs b/src/solvers/gpu/output.rs
index 43e0d98..6bf7bcb 100644
--- a/src/solvers/gpu/output.rs
+++ b/src/solvers/gpu/output.rs
@@ -1,4 +1,4 @@
-use super::{Message, ResultMessage};
+use super::{Message, ResultMessage, RowResult};
use std::collections::{HashMap, HashSet};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::JoinHandle;
@@ -44,6 +44,9 @@ impl InBuffer {
self.row_requests.insert(id, output);
}
}
+ Message::CpuDone => {
+ return None;
+ }
Message::Terminate => {
return None;
}
@@ -66,33 +69,19 @@ impl InBuffer {
}
}
-#[derive(PartialEq, Eq, Hash)]
-pub struct RowResult {
- rows: Vec<u32>,
-}
-
-impl RowResult {
- fn new(mut rows: Vec<u32>) -> Self {
- rows.push(0);
- rows.sort();
- Self { rows }
- }
- fn output(&self) {
- println!("{:?}", self.rows);
- }
-}
-
pub struct Output {
input: InBuffer,
permutations: Vec<Vec<u32>>,
permutations_mask: Vec<u64>,
results: HashSet<RowResult>,
+ result_sender: Sender<Message>,
}
impl Output {
pub fn launch_sevice(
permutations: &[Vec<u32>],
permutations_mask: &[u64],
+ result_sender: Sender<Message>,
) -> (Sender<Message>, JoinHandle<()>) {
let (sender, receiver) = channel();
let input = InBuffer::new(receiver);
@@ -102,11 +91,12 @@ impl Output {
permutations: permutations.into(),
permutations_mask: permutations_mask.into(),
results: HashSet::new(),
+ result_sender,
};
(
sender,
std::thread::Builder::new()
- .name("GPU Manager Deamon".into())
+ .name("GPU Output Deamon".into())
.spawn(move || {
output.run();
})
@@ -118,12 +108,20 @@ impl Output {
loop {
if let Some(walls) = self.input.read() {
for wall in walls {
+ if !self.results.contains(&wall) {
+ self.result_sender
+ .send(Message::RowResult(wall.clone()))
+ .or_else(|_| Err(println!("Failed to transmit result back")));
+ }
self.results.insert(wall);
}
} else {
for wall in self.results {
wall.output()
}
+ self.result_sender.send(Message::GpuDone).unwrap();
+ // wait for second exit signal
+ self.input.read();
return;
}
}