summaryrefslogtreecommitdiff
path: root/src/database.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/database.rs')
-rw-r--r--src/database.rs164
1 files changed, 94 insertions, 70 deletions
diff --git a/src/database.rs b/src/database.rs
index b2cd323..ecb88d0 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -1,22 +1,19 @@
use crate::errors::Error;
-use postgres::{Client, NoTls};
use rand::prelude::*;
-use std::sync::{Arc, Mutex};
+use std::{mem::zeroed, sync::Arc};
+use tokio::sync::Mutex;
+use tokio_postgres::{Client, NoTls};
+pub static mut _client: Option<Client> = None;
lazy_static! {
- static ref CLIENT: Arc<Mutex<Client>> = Arc::new(Mutex::new(
- Client::connect(
- "host=127.0.0.1 user=spotify_intersect password=example dbname=track_db",
- NoTls
- )
- .expect("failed to connect to database")
- ));
+ static ref CLIENT: Client = unsafe { _client.take().unwrap() };
}
-pub fn initialize_db() -> Result<(), Error> {
- let mut client = CLIENT.lock()?;
- client.batch_execute(
- r#"
+pub async fn initialize_db() -> Result<(), Error> {
+ let client = &CLIENT;
+ client
+ .batch_execute(
+ r#"
CREATE TABLE IF NOT EXISTS track (
track_id SERIAL PRIMARY KEY,
track_code TEXT NOT NULL UNIQUE,
@@ -51,92 +48,105 @@ pub fn initialize_db() -> Result<(), Error> {
FROM user_track_raw
);
"#,
- )?;
+ )
+ .await?;
Ok(())
}
use rspotify::model::track::FullTrack;
-pub fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), Error> {
- let mut client = CLIENT.lock()?;
+pub async fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), Error> {
+ let client = &CLIENT;
if track.id.is_none() {
println!("{:#?}", track);
return Err("failed to load get track information".into());
}
print!(" {} ", track.id.clone()?);
- client.execute(
- "INSERT INTO track (track_code, name, artist, popularity)
+ client
+ .execute(
+ "INSERT INTO track (track_code, name, artist, popularity)
VALUES ($1, $2, $3, $4)
ON CONFLICT DO NOTHING",
- &[
- &(track.id.clone()?),
- &track.name,
- &track.artists[0].name,
- &(track.popularity as i32),
- ],
- )?;
- let track_id: i32 = client.query(
- "SELECT track_id FROM track where track_code = $1;",
- &[&(track.id?)],
- )?[0]
+ &[
+ &(track.id.clone()?),
+ &track.name,
+ &track.artists[0].name,
+ &(track.popularity as i32),
+ ],
+ )
+ .await?;
+ let track_id: i32 = client
+ .query(
+ "SELECT track_id FROM track where track_code = $1;",
+ &[&(track.id?)],
+ )
+ .await?[0]
.get(0);
println!("uid: {} tid: {}", user_id, track_id);
- client.execute(
- "
+ client
+ .execute(
+ "
INSERT INTO user_track_raw (track_id, user_id, count)
VALUES ($1, $2, $3)
ON CONFLICT
ON CONSTRAINT track_user_pkey
DO NOTHING;
",
- &[&track_id, &user_id, &0],
- )?;
- client.execute(
- "UPDATE user_track SET count = count + $3 WHERE track_id = $1 AND user_id = $2;",
- &[&track_id, &user_id, &weight],
- )?;
+ &[&track_id, &user_id, &0],
+ )
+ .await?;
+ client
+ .execute(
+ "UPDATE user_track SET count = count + $3 WHERE track_id = $1 AND user_id = $2;",
+ &[&track_id, &user_id, &weight],
+ )
+ .await?;
Ok(())
}
-pub fn insert_user(name: &str, lobby: &str) -> Result<i32, Error> {
- let mut client = CLIENT.lock()?;
- let x = get_lid(lobby, &mut *client)?;
+pub async fn insert_user(name: &str, lobby: &str) -> Result<i32, Error> {
+ let client = &CLIENT;
+ let x = get_lid(lobby, &*client).await?;
client.execute(
"INSERT INTO suser (user_name, lobby_id) VALUES ($1, $2) ON CONFLICT (user_name, lobby_id) DO NOTHING;",
&[&name, &x],
- )?;
- let db_user_id = get_uid(name, lobby, &mut *client)?;
- client.execute("DELETE FROM user_track WHERE user_id = $1;", &[&db_user_id])?;
+ ).await?;
+ let db_user_id = get_uid(name, lobby, &*client).await?;
+ client
+ .execute("DELETE FROM user_track WHERE user_id = $1;", &[&db_user_id])
+ .await?;
Ok(db_user_id)
}
-fn get_lid(lobby: &str, client: &mut postgres::Client) -> Result<i32, Error> {
+async fn get_lid(lobby: &str, client: &Client) -> Result<i32, Error> {
let x: i32 = client
- .query_one("SELECT lobby_id FROM lobby WHERE token = $1;", &[&lobby])?
+ .query_one("SELECT lobby_id FROM lobby WHERE token = $1;", &[&lobby])
+ .await?
.get(0);
Ok(x)
}
-fn get_uid(name: &str, lobby: &str, client: &mut postgres::Client) -> Result<i32, Error> {
+async fn get_uid(name: &str, lobby: &str, client: &Client) -> Result<i32, Error> {
let x: i32 = client
- .query_one("SELECT user_id FROM suser JOIN lobby USING (lobby_id) WHERE user_name = $1 AND token = $2;", &[&name, &lobby])?
+ .query_one("SELECT user_id FROM suser JOIN lobby USING (lobby_id) WHERE user_name = $1 AND token = $2;", &[&name, &lobby]).await?
.get(0);
Ok(x)
}
-pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> {
- let mut client = CLIENT.lock()?;
+pub async fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> {
+ let client = &CLIENT;
let mut songs = String::new();
let names: Vec<String> = names.iter().map(|x| x.to_string()).collect();
- let users = names
- .iter()
- .fold(String::new(), |a, name| format!("{}, ({})", a, name));
- let users: String = users.chars().skip(2).collect();
- for row in client.query(
+ //let users: Vec<&str> = names.iter().map(|x| x.as_str()).collect();
+ //println!("users to match: {:?}", users);
+ let placeholders =
+ (2..=(names.len() + 1)).fold(String::new(), |p, a| format!("{}, (${})", p, a));
+ let placeholders: String = placeholders.chars().skip(2).collect();
+ let query = format!(
"
WITH users AS (
SELECT *
- FROM ( VALUES $1 AS _ (user_id)
+ FROM ( VALUES {} ) AS _ (user_id)
)
SELECT track_id, name, artist
FROM track
@@ -146,15 +156,23 @@ pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> {
JOIN suser USING (user_id)
JOIN track USING (track_id)
JOIN lobby USING (lobby_id)
- WHERE suser.user_name IN (SELECT * FROM users) AND track = $2
+ WHERE suser.user_name IN (SELECT * FROM users) AND token = $1
GROUP BY track_id
HAVING COUNT(track_id) = (SELECT COUNT(*) FROM users)
ORDER BY SUM(score) DESC
) AS _ USING (track_id)
;
",
- &[&users, &lobby],
- )? {
+ placeholders
+ );
+ let mut values = vec![lobby];
+ values.extend(names);
+ println!("{:?}", values);
+
+ use futures::{pin_mut, TryStreamExt};
+ let mut it = client.query_raw(query.as_str(), values).await?;
+ pin_mut!(it);
+ for row in it.try_next().await? {
let name: String = row.get(1);
let artist: String = row.get(2);
songs = format!("{}{} by {}\n", songs, name, artist);
@@ -162,33 +180,39 @@ pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> {
Ok(songs)
}
-pub fn get_users(lobby: &str) -> Result<String, Error> {
- let mut client = CLIENT.lock()?;
+pub async fn get_users(lobby: &str) -> Result<String, Error> {
+ let client = &CLIENT;
let mut users = String::new();
- for row in client.query(
- "SELECT user_name FROM suser JOIN lobby USING (lobby_id) WHERE token = $1",
- &[&lobby],
- )? {
+ for row in client
+ .query(
+ "SELECT user_name FROM suser JOIN lobby USING (lobby_id) WHERE token = $1",
+ &[&lobby],
+ )
+ .await?
+ {
let user: String = row.get(0);
users = format!("{}{}\n", users, user);
}
Ok(users)
}
-pub fn create_lobby(name: &str) -> Result<String, Error> {
- let mut client = CLIENT.lock()?;
+pub async fn create_lobby(name: &str) -> Result<String, Error> {
+ let client = &CLIENT;
let mut token = String::new();
while token.is_empty()
|| client
.query_one("SELECT lobby_id FROM lobby WHERE token = $1", &[&token])
+ .await
.is_ok()
{
let rand: [u8; 20] = rand::thread_rng().gen();
token = base64::encode_config(&rand, base64::URL_SAFE);
}
- client.execute(
- "INSERT INTO lobby (token, lobby_name) VALUES ($1, $2);",
- &[&token, &name],
- )?;
+ client
+ .execute(
+ "INSERT INTO lobby (token, lobby_name) VALUES ($1, $2);",
+ &[&token, &name],
+ )
+ .await?;
Ok(token)
}