summaryrefslogtreecommitdiff
path: root/src/solvers/gpu/manager.rs
blob: 1dd6a4d02b3b61da0480d470521df9ce9351380f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
use std::sync::mpsc::{Receiver, Sender, channel};
use std::thread::JoinHandle;
use super::*;

/// Fixed-size accumulator that batches the bitmasks of incoming check requests.
///
/// Masks are written into consecutive slots. Once the last slot has been
/// written the whole batch is handed out and writing wraps around to the first
/// slot again.
#[derive(Debug)]
struct RequestBuffer {
    /// Backing storage, one bitmask per slot.
    mask_buff: Vec<u64>,
    /// Index of the slot the next request is written to.
    pointer: usize,
}

impl RequestBuffer {
    /// Creates a buffer with `size` zero-initialised slots.
    ///
    /// # Panics
    ///
    /// Panics if `size` is 0. A buffer without slots cannot store any request,
    /// so the mistake is reported here instead of surfacing later as an
    /// out-of-bounds index inside [`RequestBuffer::read`].
    pub fn new(size: usize) -> Self {
        assert!(size > 0, "RequestBuffer requires at least one slot");
        RequestBuffer {
            mask_buff: vec![0; size],
            pointer: 0,
        }
    }

    /// Stores the bitmask of `request` in the next slot.
    ///
    /// Returns `Some` with every buffered mask when this call filled the last
    /// slot, and `None` otherwise. After a batch has been returned the next
    /// call starts overwriting the buffer from the first slot.
    pub fn read(&mut self, request: CheckRequest) -> Option<&[u64]> {
        self.mask_buff[self.pointer] = request.bitmask;
        self.pointer += 1;
        if self.pointer == self.mask_buff.len() {
            // Batch complete: rewind so the following request reuses slot 0.
            self.pointer = 0;
            Some(self.mask_buff.as_slice())
        } else {
            None
        }
    }
}

/// State owned by the GPU manager thread.
///
/// Incoming check requests are collected in per-queue buffers; whenever a
/// buffer reports a complete batch, that batch is forwarded as one job to the
/// host service and to the output service.
pub struct OclManager {
    /// Identifier attached to each forwarded job; incremented after every dispatch.
    job_id: u64,
    /// Channel to the service started by `Host::launch_sevice`.
    host_sender: Sender<Message>,
    /// Channel to the service started by `Output::launch_sevice`.
    output_sender: Sender<Message>,
    /// Inbox on which the manager receives its messages.
    reciever: Receiver<Message>,
    /// Request buffers, indexed by the `queue` field of a check request.
    buffers: Vec<RequestBuffer>,
    /// Join handle of the output service thread.
    output_handle: JoinHandle<String>,
    /// Join handle of the host service thread.
    host_handle: JoinHandle<String>,
}

impl OclManager {
    pub fn launch_sevice(
        permutations: &[&[u32]],
        permutations_mask: &[u64],
        n: u32,
        // Workgroup size, set to 0 for max
        wg_size: u32,
    ) -> (Sender<Message>, JoinHandle<String>) {
        let (h, w) = crate::solvers::wall_stats(n);
        let src = include_str!("check.cl");
        let (output_sender, output_handle) =
            super::output::Output::launch_sevice(permutations, permutations_mask, n, h, w);
        let (host_sender, host_handle) =
            super::host::Host::launch_sevice(permutations_mask, n, h, w, wg_size as usize, src);

        let (receiver, sender) = channel();

        let mut buffers = Vec::with_capacity((n - h + 1) as usize);
        for _ in 0..=(n - h) {
            buffers.push(RequestBuffer::new(wg_size as usize));
        }

        let manager = Self {
            0,
            host_sender,
            output_sender,
            receiver,
            buffers,
            output_handle,
            host_handle,
        }
        (sender, 
        std::thread::Builder::new()
            .name("GPU Manager Deamon".into())
            .spawn(move || {
                manager.run();
            })
            .unwrap())
        
    }

    fn run(mut self) {
        loop {
            match self.reciever.recv().expect("Channel to GPU Manager broke") {
                Message::CheckRequest(request) => {
                    if let Some(buffer) = self.buffers[request.queue as usize].read(request) {
                        self.host_sender
                            .send(Message::HostMessage((self.job_id, buffer.0.into())));
                        self.output_sender
                            .send(Message::OutputMessage((self.job_id, buffer.1.into())));
                        self.job_id += 1;
                    }
                }
                Message::Terminate => {
                    panic!("flush buffers");
                    self.host_sender.send(Message::Terminate);
                    self.host_handle.join();
                    self.output_sender.send(Message::Terminate);
                    self.output_handle.join();
                    return;
                }
                _ => println!("Invalid MessageType"),
            }
        }
    }
}