summaryrefslogtreecommitdiff
path: root/src/solvers/gpu/manager.rs
blob: e210af2b2e5de5255cce00bbfa89f477d93ae2b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use super::{CheckRequest, Message};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread::JoinHandle;

/// Fixed-size batch of pending check requests for a single queue.
///
/// Requests are written slot by slot; once every slot has been filled the
/// whole batch is handed out and the write position wraps back to slot 0.
#[derive(Debug)]
struct RequestBuffer {
    /// `bitmask` of each buffered request, indexed by slot.
    mask_buff: Vec<u64>,
    /// `rows` of each buffered request, indexed by the same slot as `mask_buff`.
    row_buff: Vec<Vec<u32>>,
    /// Index of the next free slot, i.e. the number of requests currently buffered.
    pointer: usize,
}

impl RequestBuffer {
    /// Creates a buffer that collects `size` requests before releasing them as one batch.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero: a zero-capacity buffer can never hold a request,
    /// so every call to [`RequestBuffer::read`] would otherwise fail with an
    /// index-out-of-bounds panic.
    pub fn new(size: usize) -> Self {
        assert!(size > 0, "RequestBuffer size must be non-zero");
        RequestBuffer {
            mask_buff: vec![0; size],
            row_buff: vec![Vec::new(); size],
            pointer: 0,
        }
    }

    /// Stores `request` in the next free slot.
    ///
    /// Returns `Some((bitmasks, rows))` covering all slots once this request
    /// fills the buffer; the write position then wraps back to the first slot.
    /// Otherwise returns `None`. The returned slices borrow the buffer, and
    /// their slots are overwritten by subsequent calls.
    pub fn read(&mut self, request: CheckRequest) -> Option<(&[u64], &[Vec<u32>])> {
        self.mask_buff[self.pointer] = request.bitmask;
        self.row_buff[self.pointer] = request.rows;
        self.pointer += 1;
        if self.pointer < self.mask_buff.len() {
            return None;
        }
        // Buffer is full: rewind so the next request starts a fresh batch.
        self.pointer = 0;
        Some((self.mask_buff.as_slice(), self.row_buff.as_slice()))
    }
}

/// State owned by the GPU manager thread.
///
/// Batches incoming check requests per queue and forwards full batches to the
/// host and output services, whose channels and thread handles it owns.
// NOTE(review): fields drop in declaration order (senders before handles);
// keep the order in mind before rearranging.
pub struct OclManager {
    /// Id attached to the next dispatched batch; incremented after each dispatch.
    job_id: u64,
    /// Channel to the host service (batched bitmasks and `Terminate`).
    host_sender: Sender<Message>,
    /// Channel to the output service (batched rows and `Terminate`).
    output_sender: Sender<Message>,
    /// Inbound channel on which the manager receives its messages.
    receiver: Receiver<Message>,
    /// One request batch buffer per queue index.
    buffers: Vec<RequestBuffer>,
    /// Handle returned when the output service was launched; joined on shutdown.
    output_handle: JoinHandle<()>,
    /// Handle returned when the host service was launched; joined on shutdown.
    host_handle: JoinHandle<()>,
}

impl OclManager {
    pub fn launch_sevice(
        permutations: &[Vec<u32>],
        permutations_mask: &[u64],
        n: u32,
        // Workgroup size, set to 0 for max
        wg_size: u32,
    ) -> (Sender<Message>, JoinHandle<()>) {
        let (h, w) = crate::solvers::wall_stats(n);
        let src = include_str!("check.cl");
        let (output_sender, output_handle) =
            super::output::Output::launch_sevice(permutations, permutations_mask);
        let (host_sender, host_handle) = super::host::Host::launch_sevice(
            permutations_mask,
            n,
            h,
            w,
            wg_size as usize,
            src,
            output_sender.clone(),
        )
        .unwrap();

        let (sender, receiver) = channel();

        let mut buffers = Vec::with_capacity((n - h + 1) as usize);
        for _ in 0..=(n - h) {
            buffers.push(RequestBuffer::new(wg_size as usize));
        }

        let manager = Self {
            job_id: 0,
            host_sender,
            output_sender,
            receiver,
            buffers,
            output_handle,
            host_handle,
        };
        (
            sender,
            std::thread::Builder::new()
                .name("GPU Manager Deamon".into())
                .spawn(move || {
                    manager.run();
                })
                .unwrap(),
        )
    }

    fn run(mut self) {
        loop {
            match self.receiver.recv().expect("Channel to GPU Manager broke") {
                Message::CheckRequest(request) => {
                    let queue = request.queue;
                    println!("{}", queue);
                    if let Some(buffer) = self.buffers[queue as usize].read(request) {
                        self.host_sender
                            .send(Message::HostMessage((self.job_id, queue, buffer.0.into())))
                            .unwrap();
                        self.output_sender
                            .send(Message::OutputMessage((self.job_id, buffer.1.into())))
                            .unwrap();
                        self.job_id += 1;
                    }
                }
                Message::Terminate => {
                    panic!("flush buffers");
                    self.host_sender.send(Message::Terminate);
                    self.host_handle.join();
                    self.output_sender.send(Message::Terminate);
                    self.output_handle.join();
                    return;
                }
                _ => println!("Invalid MessageType"),
            }
        }
    }
}