use crate::errors::Error; use rand::prelude::*; use std::{mem::zeroed, sync::Arc}; use tokio::sync::Mutex; use tokio_postgres::{Client, NoTls}; pub static mut _client: Option = None; lazy_static! { static ref CLIENT: Client = unsafe { _client.take().unwrap() }; } pub async fn initialize_db() -> Result<(), Error> { let client = &CLIENT; client .batch_execute( r#" CREATE TABLE IF NOT EXISTS track ( track_id SERIAL PRIMARY KEY, track_code TEXT NOT NULL UNIQUE, name TEXT NOT NULL, artist TEXT NOT NULL, popularity int DEFAULT 50 ); CREATE TABLE IF NOT EXISTS lobby ( lobby_id SERIAL PRIMARY KEY, token TEXT NOT NULL UNIQUE, lobby_name TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS suser ( user_id SERIAL PRIMARY KEY, user_name TEXT NOT NULL, lobby_id int REFERENCES lobby (lobby_id) ON UPDATE CASCADE ON DELETE CASCADE, CONSTRAINT suser_name_lobby UNIQUE (lobby_id, user_name) ); CREATE TABLE IF NOT EXISTS user_track_raw ( track_id int REFERENCES track (track_id) ON UPDATE CASCADE ON DELETE CASCADE, user_id int REFERENCES suser (user_id) ON UPDATE CASCADE ON DELETE CASCADE, count int NOT NULL DEFAULT 1, CONSTRAINT track_user_pkey PRIMARY KEY (track_id, user_id) ); CREATE OR REPLACE VIEW user_track AS ( SELECT *, "count" / ( SELECT SUM("count") FROM user_track_raw AS ut WHERE ut.user_id = user_id )::decimal AS score FROM user_track_raw ); "#, ) .await?; Ok(()) } use rspotify::model::track::FullTrack; pub async fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), Error> { let client = &CLIENT; if track.id.is_none() { println!("{:#?}", track); return Err("failed to load get track information".into()); } print!(" {} ", track.id.clone().unwrap()); client .execute( "INSERT INTO track (track_code, name, artist, popularity) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", &[ &(track.id.clone().unwrap()), &track.name, &track.artists[0].name, &(track.popularity as i32), ], ) .await?; let track_id: i32 = client .query( "SELECT track_id FROM track where track_code = $1;", &[&(track.id.unwrap())], ) .await?[0] .get(0); println!("uid: {} tid: {}", user_id, track_id); client .execute( " INSERT INTO user_track_raw (track_id, user_id, count) VALUES ($1, $2, $3) ON CONFLICT ON CONSTRAINT track_user_pkey DO NOTHING; ", &[&track_id, &user_id, &0], ) .await?; client .execute( "UPDATE user_track SET count = count + $3 WHERE track_id = $1 AND user_id = $2;", &[&track_id, &user_id, &weight], ) .await?; Ok(()) } pub async fn insert_user(name: &str, lobby: &str) -> Result { let client = &CLIENT; let x = get_lid(lobby, &*client).await?; client.execute( "INSERT INTO suser (user_name, lobby_id) VALUES ($1, $2) ON CONFLICT (user_name, lobby_id) DO NOTHING;", &[&name, &x], ).await?; let db_user_id = get_uid(name, lobby, &*client).await?; client .execute("DELETE FROM user_track WHERE user_id = $1;", &[&db_user_id]) .await?; Ok(db_user_id) } async fn get_lid(lobby: &str, client: &Client) -> Result { let x: i32 = client .query_one("SELECT lobby_id FROM lobby WHERE token = $1;", &[&lobby]) .await? .get(0); Ok(x) } async fn get_uid(name: &str, lobby: &str, client: &Client) -> Result { let x: i32 = client .query_one("SELECT user_id FROM suser JOIN lobby USING (lobby_id) WHERE user_name = $1 AND token = $2;", &[&name, &lobby]).await? .get(0); Ok(x) } pub async fn match_users(lobby: String, names: &[&str]) -> Result { let client = &CLIENT; let mut songs = String::new(); let names: Vec = names.iter().map(|x| x.trim().to_string()).collect(); //let users: Vec<&str> = names.iter().map(|x| x.as_str()).collect(); //println!("users to match: {:?}", users); let placeholders = (2..=(names.len() + 1)).fold(String::new(), |p, a| format!("{}, (${})", p, a)); let placeholders: String = placeholders.chars().skip(2).collect(); let query = format!( " WITH users AS ( SELECT * FROM ( VALUES {} ) AS _ (user_id) ) SELECT track_id, name, artist FROM track JOIN ( SELECT track_id FROM user_track JOIN suser USING (user_id) JOIN track USING (track_id) JOIN lobby USING (lobby_id) WHERE suser.user_name IN (SELECT * FROM users) AND token = $1 GROUP BY track_id HAVING COUNT(track_id) = (SELECT COUNT(*) FROM users) ORDER BY EXP(SUM(LN(score))) DESC ) AS _ USING (track_id) ; ", placeholders ); let mut values = vec![lobby]; values.extend(names); println!("{:?}", values); use futures::{pin_mut, TryStreamExt}; let it = client.query_raw(query.as_str(), values).await?; pin_mut!(it); while let Some(row) = it.try_next().await? { let name: String = row.get(1); let artist: String = row.get(2); songs = format!("{}{} by {}\n", songs, name, artist); } Ok(songs) } pub async fn get_users(lobby: &str) -> Result { let client = &CLIENT; let mut users = String::new(); for row in client .query( "SELECT user_name FROM suser JOIN lobby USING (lobby_id) WHERE token = $1", &[&lobby], ) .await? { let user: String = row.get(0); users = format!("{}{}\n", users, user); } Ok(users) } pub async fn create_lobby(name: &str) -> Result { let client = &CLIENT; let mut token = String::new(); while token.is_empty() || client .query_one("SELECT lobby_id FROM lobby WHERE token = $1", &[&token]) .await .is_ok() { let rand: [u8; 20] = rand::thread_rng().gen(); token = base64::encode_config(&rand, base64::URL_SAFE); } client .execute( "INSERT INTO lobby (token, lobby_name) VALUES ($1, $2);", &[&token, &name], ) .await?; Ok(token) }