use super::{GpuError, Message, ResultMessage}; use ocl::{flags, Buffer, Context, Device, Kernel, Platform, Program, Queue}; use std::sync::mpsc::{channel, Receiver, Sender}; #[derive(Debug)] pub struct Host { #[allow(unused)] platform: Platform, #[allow(unused)] device: Device, #[allow(unused)] context: Context, program: Program, queue: Queue, n: u32, h: u32, w: u32, /// Workgroup size, set to 0 for max wg_size: usize, permutations: Buffer, receiver: Receiver, output_sender: Sender, } impl Host { pub fn launch_sevice( permutation_masks: &[u64], n: u32, h: u32, w: u32, wg_size: usize, src: &str, output_sender: Sender, ) -> ocl::Result<(Sender, std::thread::JoinHandle<()>)> { let platform = ocl::Platform::default(); let device = ocl::Device::first(platform)?; let context = ocl::Context::builder() .platform(platform) .devices(device.clone()) .build()?; let queue = ocl::Queue::new(&context, device, None)?; let program = Program::builder() .devices(device) .src(src) .build(&context)?; let buffer = ocl::Buffer::builder() .queue(queue.clone()) .flags(flags::MEM_READ_WRITE) .copy_host_slice(permutation_masks) .len(permutation_masks.len()) .build()?; let (sender, receiver) = channel(); let solver = Self { platform, device, context, program, queue, n, h, w, wg_size, permutations: buffer, receiver, output_sender, }; let handle = std::thread::Builder::new() .name("GPU Host Deamon".into()) .spawn(move || { if let Err(err) = solver.run() { println!("{}", err); } }) .unwrap(); println!("started gpu thread"); Ok((sender, handle)) } fn get_dim(&self, queue: usize) -> usize { let chunk = self.permutations.len() / self.n as usize; let dim = (queue + 1) * chunk; (dim + self.wg_size - 1) / self.wg_size * self.wg_size } fn get_off(&self, queue: usize) -> usize { let chunk = self.permutations.len() / self.n as usize; if self.permutations.len() < chunk + self.get_dim(queue) { panic!("workgroup size too big; offset underflow") } self.permutations.len() - chunk - self.get_dim(queue) } fn get_res(&self, queue: usize) -> usize { let dim = self.get_dim(queue); dim * self.get_res_save_dim() } fn get_res_save_dim(&self) -> usize { (self.wg_size + 63) / 64 } fn run(self) -> Result<(), GpuError> { let queues = (self.n - self.h + 1) as usize; let mut instruction_buffer = Vec::with_capacity((self.n - self.h) as usize); let mut result_buffer = Vec::with_capacity((self.n - self.h) as usize); for i in 0..queues { let buffer: Buffer = Buffer::builder() .queue(self.queue.clone()) .len(self.wg_size) .flags(flags::MEM_READ_WRITE) .build()?; instruction_buffer.push(buffer); let results: Buffer = Buffer::builder() .queue(self.queue.clone()) .len(self.get_res(i)) .flags(flags::MEM_READ_WRITE) .build()?; result_buffer.push(results); } println!("finished gpu setup"); loop { match self.receiver.recv()? { Message::CpuDone => { self.output_sender.send(Message::CpuDone)?; return Ok(()); } Message::Terminate => { return Ok(()); } Message::HostMessage((id, i, buffer)) => { let i = i as usize; let off = self.get_off(i); let dim = self.get_dim(i); let res = self.get_res(i); let res_size = self.get_res_save_dim(); instruction_buffer[i].write(&buffer).enq()?; //println!("dim: {}", dim); //println!("off: {}", self.get_off(i)); //println!("result size: {}", self.get_res_save_dim()); let kernel = Kernel::builder() .program(&self.program) .name("check") .queue(self.queue.clone()) .global_work_size(dim) .arg(&self.permutations) .arg(&result_buffer[i]) .arg(&instruction_buffer[i]) .arg_local::(self.wg_size) .arg(self.n) .arg(self.w) .arg(off) .build()?; unsafe { kernel .cmd() .queue(&self.queue) .global_work_offset(kernel.default_global_work_offset()) .global_work_size(dim) .local_work_size(self.wg_size) .enq()?; } // (5) Read results from the device into a vector (`::block` not shown): let mut data = vec![0u64; res]; result_buffer[i] .cmd() .queue(&self.queue) .offset(0) .read(&mut data) .enq()?; self.queue.finish()?; self.output_sender .send(Message::ResultMessage(ResultMessage::new( data, off, res_size, self.wg_size, id, )))?; } m => println!("Invalid MessageType {:?} recived by host", m), } } } }