summaryrefslogtreecommitdiff
path: root/src/solvers/gpu/host.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/solvers/gpu/host.rs')
-rw-r--r--src/solvers/gpu/host.rs235
1 files changed, 235 insertions, 0 deletions
diff --git a/src/solvers/gpu/host.rs b/src/solvers/gpu/host.rs
new file mode 100644
index 0000000..6b79078
--- /dev/null
+++ b/src/solvers/gpu/host.rs
@@ -0,0 +1,235 @@
+use ocl::{flags, Buffer, Context, Device, Kernel, Platform, Program, Queue};
+use std::sync::mpsc::{Receiver, Sender};
+
+/// State owned by the GPU worker thread: the OpenCL handles, the problem
+/// parameters handed to the kernel, and one request queue per job channel.
+#[derive(Debug)]
+pub struct Host {
+ // NOTE(review): presumably stored only so the handle lives as long as the
+ // worker — confirm; nothing in this file reads it after construction.
+ #[allow(unused)]
+ platform: Platform,
+ // NOTE(review): same as `platform` — kept but not read after construction.
+ #[allow(unused)]
+ device: Device,
+ // NOTE(review): same as `platform` — kept but not read after construction.
+ #[allow(unused)]
+ context: Context,
+ /// Program built from the kernel source given to `launch_sevice`; `run`
+ /// looks up the kernel named `check` in it.
+ program: Program,
+ /// Command queue used for every buffer and kernel launch.
+ queue: Queue,
+ /// Passed to the kernel; also the divisor that splits `permutations` into
+ /// chunks (`len / n`).
+ n: u32,
+ /// Together with `n` determines the number of request queues
+ /// (`n - h + 1`).
+ h: u32,
+ /// Passed to the kernel unchanged.
+ w: u32,
+ /// Workgroup size, used as the local work size of every launch.
+ /// Callers of `launch_sevice` pass 0 to request the device maximum; the
+ /// value stored here has already been resolved.
+ wg_size: usize,
+ /// Device buffer initialised with a copy of the permutation masks.
+ permutations: Buffer<u64>,
+ /// One request buffer per job channel returned by `launch_sevice`.
+ rec_queues: Vec<RequestBuffer>,
+ /// Results decoded from the kernel output, accumulated by `run`.
+ walls: Vec<Vec<u32>>,
+}
+
+impl Host {
+    /// Initialises OpenCL and spawns the GPU worker thread.
+    ///
+    /// Picks the default platform and its first device, compiles `src`,
+    /// uploads `permutation_masks` to the device and starts a thread running
+    /// [`Host::run`]. One channel is created per request queue
+    /// (`n - h + 1` in total); the sending halves are returned in queue order.
+    ///
+    /// `wg_size` is the workgroup size to use; pass `0` to use the device
+    /// maximum. `n`, `w` are forwarded to the kernel, `h` only influences the
+    /// number of queues.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if
+    /// * `n` is zero or `h` is greater than `n`,
+    /// * `wg_size` exceeds the device's maximum workgroup size,
+    /// * any OpenCL object cannot be created or the program fails to build,
+    /// * the worker thread cannot be spawned.
+    pub fn launch_sevice(
+        permutation_masks: &[u64],
+        n: u32,
+        h: u32,
+        w: u32,
+        mut wg_size: usize,
+        src: &str,
+    ) -> ocl::Result<Vec<Sender<Job>>> {
+        // The worker divides by `n` and the queue count is `n - h + 1`;
+        // reject inputs that would panic (or wrap around in release builds).
+        if n == 0 || h > n {
+            return Err(ocl::Error::from("invalid problem size"));
+        }
+        let queue_count = (n - h) as usize + 1;
+
+        let platform = ocl::Platform::default();
+        let device = ocl::Device::first(platform)?;
+
+        // Resolve the workgroup size before building anything expensive.
+        let max_wg_size = device.max_wg_size()?;
+        if wg_size == 0 {
+            wg_size = max_wg_size;
+        } else if wg_size > max_wg_size {
+            return Err(ocl::Error::from("invalid workgroup size"));
+        }
+
+        let context = ocl::Context::builder()
+            .platform(platform)
+            .devices(device)
+            .build()?;
+        let queue = ocl::Queue::new(&context, device, None)?;
+
+        let program = Program::builder()
+            .devices(device)
+            .src(src)
+            .build(&context)?;
+        let permutations = ocl::Buffer::builder()
+            .queue(queue.clone())
+            .flags(flags::MEM_READ_WRITE)
+            .copy_host_slice(permutation_masks)
+            .len(permutation_masks.len())
+            .build()?;
+
+        // One channel per queue: the sender goes to the caller, the receiver
+        // is wrapped in a request buffer owned by the worker.
+        let mut senders = Vec::with_capacity(queue_count);
+        let mut rec_queues = Vec::with_capacity(queue_count);
+        for _ in 0..queue_count {
+            let (sx, rx) = std::sync::mpsc::channel();
+            senders.push(sx);
+            rec_queues.push(RequestBuffer::new(wg_size, rx));
+        }
+
+        let solver = Self {
+            platform,
+            device,
+            context,
+            program,
+            queue,
+            n,
+            h,
+            w,
+            wg_size,
+            permutations,
+            rec_queues,
+            walls: Vec::new(),
+        };
+        std::thread::Builder::new()
+            .name("GPU Deamon".into())
+            .spawn(move || {
+                solver.run();
+            })
+            .map_err(|err| {
+                let msg = format!("failed to spawn gpu thread: {}", err);
+                ocl::Error::from(msg.as_str())
+            })?;
+        println!("started gpu thread");
+        Ok(senders)
+    }
+
+    /// Work size for `queue`: `queue + 1` chunks of the permutation buffer
+    /// (a chunk being `len / n` entries), rounded up to a whole number of
+    /// workgroups.
+    fn get_dim(&self, queue: usize) -> usize {
+        let per_chunk = self.permutations.len() / self.n as usize;
+        let wanted = (queue + 1) * per_chunk;
+        let groups = (wanted + self.wg_size - 1) / self.wg_size;
+        groups * self.wg_size
+    }
+    /// Offset into `permutations` for `queue`, computed as
+    /// `len - chunk - get_dim(queue)` with `chunk = len / n`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the padded dimension is larger than what is left of the
+    /// buffer (the subtraction would underflow).
+    fn get_off(&self, queue: usize) -> u64 {
+        let len = self.permutations.len();
+        let chunk = len / self.n as usize;
+        // Checked arithmetic reports the underflow with the intended message
+        // in every build profile instead of relying on wrap-around.
+        let off = len
+            .checked_sub(chunk)
+            .and_then(|rest| rest.checked_sub(self.get_dim(queue)))
+            .filter(|&off| off <= isize::max_value() as usize);
+        match off {
+            Some(off) => off as u64,
+            None => panic!("workgroup size to big, offset underflow"),
+        }
+    }
+    /// Length, in `u64` words, of the result buffer for `queue`:
+    /// `get_res_save_dim()` words for each of the `get_dim(queue)` entries.
+    fn get_res(&self, queue: usize) -> usize {
+        self.get_dim(queue) * self.get_res_save_dim()
+    }
+    /// Number of 64-bit words needed for `wg_size` bits, i.e. `wg_size / 64`
+    /// rounded up.
+    fn get_res_save_dim(&self) -> usize {
+        const WORD_BITS: usize = 64;
+        (self.wg_size + (WORD_BITS - 1)) / WORD_BITS
+    }
+
+    /// Worker loop of the GPU thread; never returns.
+    ///
+    /// Allocates one instruction buffer (`wg_size` words) and one result
+    /// buffer (`get_res(i)` words) per request queue, then polls the queues
+    /// round-robin. Whenever a queue yields a batch, the batch is uploaded,
+    /// the `check` kernel is launched over `get_dim(i)` work items, and the
+    /// result words are read back and decoded: every set bit becomes a copy
+    /// of the corresponding request row with the permutation index appended,
+    /// which is printed and stored in `walls`.
+    ///
+    /// # Panics
+    ///
+    /// Panics on any OpenCL failure, if `get_off` underflows, if a decoded
+    /// instruction index is outside the rows of the request buffer, or if
+    /// there are no request queues at all.
+    fn run(mut self) -> ! {
+        let queues = self.rec_queues.len();
+        // Depends only on the workgroup size, so compute it once instead of
+        // once per decoded bit.
+        let words_per_entry = self.get_res_save_dim();
+
+        let mut instruction_buffer = Vec::with_capacity(queues);
+        let mut result_buffer = Vec::with_capacity(queues);
+        for i in 0..queues {
+            let instructions: Buffer<u64> = Buffer::builder()
+                .queue(self.queue.clone())
+                .len(self.wg_size)
+                .flags(flags::MEM_READ_WRITE)
+                .build()
+                .expect("failed to allocate instruction buffer");
+            instruction_buffer.push(instructions);
+
+            let results: Buffer<u64> = Buffer::builder()
+                .queue(self.queue.clone())
+                .len(self.get_res(i))
+                .flags(flags::MEM_READ_WRITE)
+                .build()
+                .expect("failed to allocate result buffer");
+            result_buffer.push(results);
+        }
+        println!("finished gpu setup");
+
+        for i in (0..queues).cycle() {
+            if let Some(buffer) = self.rec_queues[i].read() {
+                instruction_buffer[i]
+                    .write(buffer)
+                    .enq()
+                    .expect("failed to upload instructions");
+                let dim = self.get_dim(i);
+                let offset = self.get_off(i);
+
+                let kernel = Kernel::builder()
+                    .program(&self.program)
+                    .name("check")
+                    .queue(self.queue.clone())
+                    .global_work_size(dim)
+                    .arg(&self.permutations)
+                    .arg(&result_buffer[i])
+                    .arg(&instruction_buffer[i])
+                    .arg_local::<u64>(self.wg_size)
+                    .arg(self.n)
+                    .arg(self.w)
+                    .arg(offset)
+                    .build()
+                    .expect("failed to build `check` kernel");
+
+                // SAFETY: relies on the `check` kernel staying within the
+                // buffers bound above (sized `wg_size` and `get_res(i)`).
+                // The kernel source is supplied by the caller and is not
+                // verified here.
+                unsafe {
+                    kernel
+                        .cmd()
+                        .queue(&self.queue)
+                        .global_work_offset(kernel.default_global_work_offset())
+                        .global_work_size(dim)
+                        .local_work_size(self.wg_size)
+                        .enq()
+                        .expect("failed to enqueue `check` kernel");
+                }
+
+                // Read the result words back to the host.
+                // NOTE(review): presumably a blocking read on an in-order
+                // queue, so the kernel has finished by the time it returns —
+                // confirm against the ocl defaults.
+                let mut result = vec![0u64; dim * words_per_entry];
+                result_buffer[i]
+                    .cmd()
+                    .queue(&self.queue)
+                    .offset(0)
+                    .read(&mut result)
+                    .enq()
+                    .expect("failed to read results");
+
+                // Word 0 is deliberately not decoded (as in the original
+                // loop). NOTE(review): reason not visible here — confirm.
+                for (j, &word) in result.iter().enumerate().skip(1) {
+                    if word == 0 {
+                        // No bit set, nothing to decode.
+                        continue;
+                    }
+                    let permutation = j / words_per_entry + offset as usize;
+                    let first_instruction = (j % words_per_entry) * 64;
+                    for b in 0..64usize {
+                        if word & (1u64 << b) != 0 {
+                            let instruction = first_instruction + b;
+                            let mut wall = self.rec_queues[i].get_rows()[instruction].clone();
+                            wall.push(permutation as u32);
+                            println!("{:?}", wall);
+                            self.walls.push(wall);
+                        }
+                    }
+                }
+            }
+        }
+        // `cycle()` only terminates when the underlying range is empty.
+        panic!("gpu thread has no request queues to poll");
+    }
+}
+/*
+pub fn check(permutations: &[u64], w: u32, n: u32, mask: u64, offset: usize) -> ocl::Result<()> {
+ //println!("read src!");
+ let src = std::fs::read_to_string("src/solvers/check.cl").expect("failed to open kernel file");
+
+ //println!("created queue!");
+ println!("offset: {}", offset);
+ println!("length: {}", permutations.len() - offset);
+ let pro_que = ocl::ProQue::builder()
+ .src(src)
+ .dims(permutations.len() - offset)
+ .build()?;
+
+ let results = pro_que.create_buffer::<i32>()?;
+ let kernel = pro_que
+ .kernel_builder("check")
+ .arg(get_buffer())
+ .arg(&results)
+ .arg(mask)
+ .arg(n)
+ .arg(w)
+ .arg(offset as u64)
+ //.global_work_offset(offset)
+ .build()?;
+
+ //println!("starting calculation");
+ unsafe {
+ kernel.enq()?;
+ }
+
+ let mut vec = vec![0; results.len()];
+ results.read(&mut vec).enq()?;
+
+ if vec.iter().any(|x| *x != 0) {
+ println!("The results are now '{:?}'!", vec);
+ }
+ Ok(())
+}*/