diff options
author | Dennis Kobert <dennis@kobert.dev> | 2021-02-19 00:31:05 +0000 |
---|---|---|
committer | Dennis Kobert <dennis@kobert.dev> | 2021-02-19 00:31:05 +0000 |
commit | 5ae2fdf0e941b3ce13ad35363398782381179dad (patch) | |
tree | c0812c4204f31769bfa3bd571dde6ebe5ff348ad /src/database.rs | |
parent | 90a4cfacbb64750a6779995e91509588f78e9802 (diff) |
Update crates tu use async all the way
Diffstat (limited to 'src/database.rs')
-rw-r--r-- | src/database.rs | 164 |
1 files changed, 94 insertions, 70 deletions
diff --git a/src/database.rs b/src/database.rs index b2cd323..ecb88d0 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,22 +1,19 @@ use crate::errors::Error; -use postgres::{Client, NoTls}; use rand::prelude::*; -use std::sync::{Arc, Mutex}; +use std::{mem::zeroed, sync::Arc}; +use tokio::sync::Mutex; +use tokio_postgres::{Client, NoTls}; +pub static mut _client: Option<Client> = None; lazy_static! { - static ref CLIENT: Arc<Mutex<Client>> = Arc::new(Mutex::new( - Client::connect( - "host=127.0.0.1 user=spotify_intersect password=example dbname=track_db", - NoTls - ) - .expect("failed to connect to database") - )); + static ref CLIENT: Client = unsafe { _client.take().unwrap() }; } -pub fn initialize_db() -> Result<(), Error> { - let mut client = CLIENT.lock()?; - client.batch_execute( - r#" +pub async fn initialize_db() -> Result<(), Error> { + let client = &CLIENT; + client + .batch_execute( + r#" CREATE TABLE IF NOT EXISTS track ( track_id SERIAL PRIMARY KEY, track_code TEXT NOT NULL UNIQUE, @@ -51,92 +48,105 @@ pub fn initialize_db() -> Result<(), Error> { FROM user_track_raw ); "#, - )?; + ) + .await?; Ok(()) } use rspotify::model::track::FullTrack; -pub fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), Error> { - let mut client = CLIENT.lock()?; +pub async fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), Error> { + let client = &CLIENT; if track.id.is_none() { println!("{:#?}", track); return Err("failed to load get track information".into()); } print!(" {} ", track.id.clone()?); - client.execute( - "INSERT INTO track (track_code, name, artist, popularity) + client + .execute( + "INSERT INTO track (track_code, name, artist, popularity) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", - &[ - &(track.id.clone()?), - &track.name, - &track.artists[0].name, - &(track.popularity as i32), - ], - )?; - let track_id: i32 = client.query( - "SELECT track_id FROM track where track_code = $1;", - &[&(track.id?)], - )?[0] + &[ + &(track.id.clone()?), + &track.name, + &track.artists[0].name, + &(track.popularity as i32), + ], + ) + .await?; + let track_id: i32 = client + .query( + "SELECT track_id FROM track where track_code = $1;", + &[&(track.id?)], + ) + .await?[0] .get(0); println!("uid: {} tid: {}", user_id, track_id); - client.execute( - " + client + .execute( + " INSERT INTO user_track_raw (track_id, user_id, count) VALUES ($1, $2, $3) ON CONFLICT ON CONSTRAINT track_user_pkey DO NOTHING; ", - &[&track_id, &user_id, &0], - )?; - client.execute( - "UPDATE user_track SET count = count + $3 WHERE track_id = $1 AND user_id = $2;", - &[&track_id, &user_id, &weight], - )?; + &[&track_id, &user_id, &0], + ) + .await?; + client + .execute( + "UPDATE user_track SET count = count + $3 WHERE track_id = $1 AND user_id = $2;", + &[&track_id, &user_id, &weight], + ) + .await?; Ok(()) } -pub fn insert_user(name: &str, lobby: &str) -> Result<i32, Error> { - let mut client = CLIENT.lock()?; - let x = get_lid(lobby, &mut *client)?; +pub async fn insert_user(name: &str, lobby: &str) -> Result<i32, Error> { + let client = &CLIENT; + let x = get_lid(lobby, &*client).await?; client.execute( "INSERT INTO suser (user_name, lobby_id) VALUES ($1, $2) ON CONFLICT (user_name, lobby_id) DO NOTHING;", &[&name, &x], - )?; - let db_user_id = get_uid(name, lobby, &mut *client)?; - client.execute("DELETE FROM user_track WHERE user_id = $1;", &[&db_user_id])?; + ).await?; + let db_user_id = get_uid(name, lobby, &*client).await?; + client + .execute("DELETE FROM user_track WHERE user_id = $1;", &[&db_user_id]) + .await?; Ok(db_user_id) } -fn get_lid(lobby: &str, client: &mut postgres::Client) -> Result<i32, Error> { +async fn get_lid(lobby: &str, client: &Client) -> Result<i32, Error> { let x: i32 = client - .query_one("SELECT lobby_id FROM lobby WHERE token = $1;", &[&lobby])? + .query_one("SELECT lobby_id FROM lobby WHERE token = $1;", &[&lobby]) + .await? .get(0); Ok(x) } -fn get_uid(name: &str, lobby: &str, client: &mut postgres::Client) -> Result<i32, Error> { +async fn get_uid(name: &str, lobby: &str, client: &Client) -> Result<i32, Error> { let x: i32 = client - .query_one("SELECT user_id FROM suser JOIN lobby USING (lobby_id) WHERE user_name = $1 AND token = $2;", &[&name, &lobby])? + .query_one("SELECT user_id FROM suser JOIN lobby USING (lobby_id) WHERE user_name = $1 AND token = $2;", &[&name, &lobby]).await? .get(0); Ok(x) } -pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { - let mut client = CLIENT.lock()?; +pub async fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { + let client = &CLIENT; let mut songs = String::new(); let names: Vec<String> = names.iter().map(|x| x.to_string()).collect(); - let users = names - .iter() - .fold(String::new(), |a, name| format!("{}, ({})", a, name)); - let users: String = users.chars().skip(2).collect(); - for row in client.query( + //let users: Vec<&str> = names.iter().map(|x| x.as_str()).collect(); + //println!("users to match: {:?}", users); + let placeholders = + (2..=(names.len() + 1)).fold(String::new(), |p, a| format!("{}, (${})", p, a)); + let placeholders: String = placeholders.chars().skip(2).collect(); + let query = format!( " WITH users AS ( SELECT * - FROM ( VALUES $1 AS _ (user_id) + FROM ( VALUES {} ) AS _ (user_id) ) SELECT track_id, name, artist FROM track @@ -146,15 +156,23 @@ pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { JOIN suser USING (user_id) JOIN track USING (track_id) JOIN lobby USING (lobby_id) - WHERE suser.user_name IN (SELECT * FROM users) AND track = $2 + WHERE suser.user_name IN (SELECT * FROM users) AND token = $1 GROUP BY track_id HAVING COUNT(track_id) = (SELECT COUNT(*) FROM users) ORDER BY SUM(score) DESC ) AS _ USING (track_id) ; ", - &[&users, &lobby], - )? { + placeholders + ); + let mut values = vec![lobby]; + values.extend(names); + println!("{:?}", values); + + use futures::{pin_mut, TryStreamExt}; + let mut it = client.query_raw(query.as_str(), values).await?; + pin_mut!(it); + for row in it.try_next().await? { let name: String = row.get(1); let artist: String = row.get(2); songs = format!("{}{} by {}\n", songs, name, artist); @@ -162,33 +180,39 @@ pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { Ok(songs) } -pub fn get_users(lobby: &str) -> Result<String, Error> { - let mut client = CLIENT.lock()?; +pub async fn get_users(lobby: &str) -> Result<String, Error> { + let client = &CLIENT; let mut users = String::new(); - for row in client.query( - "SELECT user_name FROM suser JOIN lobby USING (lobby_id) WHERE token = $1", - &[&lobby], - )? { + for row in client + .query( + "SELECT user_name FROM suser JOIN lobby USING (lobby_id) WHERE token = $1", + &[&lobby], + ) + .await? + { let user: String = row.get(0); users = format!("{}{}\n", users, user); } Ok(users) } -pub fn create_lobby(name: &str) -> Result<String, Error> { - let mut client = CLIENT.lock()?; +pub async fn create_lobby(name: &str) -> Result<String, Error> { + let client = &CLIENT; let mut token = String::new(); while token.is_empty() || client .query_one("SELECT lobby_id FROM lobby WHERE token = $1", &[&token]) + .await .is_ok() { let rand: [u8; 20] = rand::thread_rng().gen(); token = base64::encode_config(&rand, base64::URL_SAFE); } - client.execute( - "INSERT INTO lobby (token, lobby_name) VALUES ($1, $2);", - &[&token, &name], - )?; + client + .execute( + "INSERT INTO lobby (token, lobby_name) VALUES ($1, $2);", + &[&token, &name], + ) + .await?; Ok(token) } |