diff options
author | Dennis Kobert <dennis@kobert.dev> | 2021-02-17 21:22:58 +0100 |
---|---|---|
committer | Dennis Kobert <dennis@kobert.dev> | 2021-02-17 21:22:58 +0100 |
commit | 90a4cfacbb64750a6779995e91509588f78e9802 (patch) | |
tree | 60f0267f46457b658e4c3826c16f4d3fdfc4a89f /src | |
parent | 1d68281dca016c63ef5cb96a2dceed8f3dbdc950 (diff) |
Add lobby concept
Diffstat (limited to 'src')
-rw-r--r-- | src/database.rs | 98 | ||||
-rw-r--r-- | src/errors.rs | 8 | ||||
-rw-r--r-- | src/main.rs | 10 | ||||
-rw-r--r-- | src/serve.rs | 31 | ||||
-rw-r--r-- | src/spotify.rs | 45 |
5 files changed, 126 insertions, 66 deletions
diff --git a/src/database.rs b/src/database.rs index 94a07d6..b2cd323 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,11 +1,15 @@ use crate::errors::Error; use postgres::{Client, NoTls}; +use rand::prelude::*; use std::sync::{Arc, Mutex}; lazy_static! { static ref CLIENT: Arc<Mutex<Client>> = Arc::new(Mutex::new( - Client::connect("host=track_db user=postgres password=example", NoTls) - .expect("failed to connect to database") + Client::connect( + "host=127.0.0.1 user=spotify_intersect password=example dbname=track_db", + NoTls + ) + .expect("failed to connect to database") )); } @@ -14,15 +18,22 @@ pub fn initialize_db() -> Result<(), Error> { client.batch_execute( r#" CREATE TABLE IF NOT EXISTS track ( - track_id SERIAL PRIMARY KEY, + track_id SERIAL PRIMARY KEY, track_code TEXT NOT NULL UNIQUE, name TEXT NOT NULL, artist TEXT NOT NULL, - popularity int DEFAULT 50 + popularity int DEFAULT 50 + ); + CREATE TABLE IF NOT EXISTS lobby ( + lobby_id SERIAL PRIMARY KEY, + token TEXT NOT NULL UNIQUE, + lobby_name TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS suser ( user_id SERIAL PRIMARY KEY, - user_name TEXT NOT NULL UNIQUE + user_name TEXT NOT NULL, + lobby_id int REFERENCES lobby (lobby_id) ON UPDATE CASCADE ON DELETE CASCADE, + CONSTRAINT suser_name_lobby UNIQUE (lobby_id, user_name) ); CREATE TABLE IF NOT EXISTS user_track_raw ( track_id int REFERENCES track (track_id) ON UPDATE CASCADE ON DELETE CASCADE, @@ -44,7 +55,7 @@ pub fn initialize_db() -> Result<(), Error> { Ok(()) } -use rspotify::spotify::model::track::FullTrack; +use rspotify::model::track::FullTrack; pub fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), Error> { let mut client = CLIENT.lock()?; @@ -54,8 +65,8 @@ pub fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), E } print!(" {} ", track.id.clone()?); client.execute( - "INSERT INTO track (track_code, name, artist, popularity) - VALUES ($1, $2, $3, $4) + "INSERT INTO track (track_code, name, artist, popularity) + VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", &[ &(track.id.clone()?), @@ -72,10 +83,10 @@ pub fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), E println!("uid: {} tid: {}", user_id, track_id); client.execute( " - INSERT INTO user_track_raw (track_id, user_id, count) - VALUES ($1, $2, $3) - ON CONFLICT - ON CONSTRAINT track_user_pkey + INSERT INTO user_track_raw (track_id, user_id, count) + VALUES ($1, $2, $3) + ON CONFLICT + ON CONSTRAINT track_user_pkey DO NOTHING; ", &[&track_id, &user_id, &0], @@ -87,36 +98,45 @@ pub fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), E Ok(()) } -pub fn insert_user(name: &str) -> Result<i32, Error> { +pub fn insert_user(name: &str, lobby: &str) -> Result<i32, Error> { let mut client = CLIENT.lock()?; + let x = get_lid(lobby, &mut *client)?; client.execute( - "INSERT INTO suser (user_name) VALUES ($1) ON CONFLICT (user_name) DO NOTHING;", - &[&name], + "INSERT INTO suser (user_name, lobby_id) VALUES ($1, $2) ON CONFLICT (user_name, lobby_id) DO NOTHING;", + &[&name, &x], )?; - let db_user_id = get_uid(name, &mut *client)?; + let db_user_id = get_uid(name, lobby, &mut *client)?; client.execute("DELETE FROM user_track WHERE user_id = $1;", &[&db_user_id])?; Ok(db_user_id) } -fn get_uid(name: &str, client: &mut postgres::Client) -> Result<i32, Error> { +fn get_lid(lobby: &str, client: &mut postgres::Client) -> Result<i32, Error> { let x: i32 = client - .query_one("SELECT user_id FROM suser where user_name = $1;", &[&name])? + .query_one("SELECT lobby_id FROM lobby WHERE token = $1;", &[&lobby])? .get(0); Ok(x) } -pub fn match_users(name1: String, name2: String) -> Result<String, Error> { +fn get_uid(name: &str, lobby: &str, client: &mut postgres::Client) -> Result<i32, Error> { + let x: i32 = client + .query_one("SELECT user_id FROM suser JOIN lobby USING (lobby_id) WHERE user_name = $1 AND token = $2;", &[&name, &lobby])? + .get(0); + Ok(x) +} + +pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { let mut client = CLIENT.lock()?; let mut songs = String::new(); + let names: Vec<String> = names.iter().map(|x| x.to_string()).collect(); + let users = names + .iter() + .fold(String::new(), |a, name| format!("{}, ({})", a, name)); + let users: String = users.chars().skip(2).collect(); for row in client.query( " WITH users AS ( SELECT * - FROM ( VALUES - - ($1), ($2) - - ) AS _ (user_id) + FROM ( VALUES $1 AS _ (user_id) ) SELECT track_id, name, artist FROM track @@ -125,14 +145,15 @@ pub fn match_users(name1: String, name2: String) -> Result<String, Error> { FROM user_track JOIN suser USING (user_id) JOIN track USING (track_id) - WHERE suser.user_name IN (SELECT * FROM users) + JOIN lobby USING (lobby_id) + WHERE suser.user_name IN (SELECT * FROM users) AND track = $2 GROUP BY track_id HAVING COUNT(track_id) = (SELECT COUNT(*) FROM users) ORDER BY SUM(score) DESC ) AS _ USING (track_id) ; ", - &[&name1.as_str(), &name2.as_str()], + &[&users, &lobby], )? { let name: String = row.get(1); let artist: String = row.get(2); @@ -141,12 +162,33 @@ pub fn match_users(name1: String, name2: String) -> Result<String, Error> { Ok(songs) } -pub fn get_users() -> Result<String, Error> { +pub fn get_users(lobby: &str) -> Result<String, Error> { let mut client = CLIENT.lock()?; let mut users = String::new(); - for row in client.query("SELECT user_name FROM suser", &[])? { + for row in client.query( + "SELECT user_name FROM suser JOIN lobby USING (lobby_id) WHERE token = $1", + &[&lobby], + )? { let user: String = row.get(0); users = format!("{}{}\n", users, user); } Ok(users) } + +pub fn create_lobby(name: &str) -> Result<String, Error> { + let mut client = CLIENT.lock()?; + let mut token = String::new(); + while token.is_empty() + || client + .query_one("SELECT lobby_id FROM lobby WHERE token = $1", &[&token]) + .is_ok() + { + let rand: [u8; 20] = rand::thread_rng().gen(); + token = base64::encode_config(&rand, base64::URL_SAFE); + } + client.execute( + "INSERT INTO lobby (token, lobby_name) VALUES ($1, $2);", + &[&token, &name], + )?; + Ok(token) +} diff --git a/src/errors.rs b/src/errors.rs index 8fb2793..544c391 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,8 +1,8 @@ use rocket::http::ContentType; use rocket::request::Request; use rocket::response::{self, Responder, Response}; -use rspotify::spotify::client::ApiError; -use rspotify::spotify::oauth2::SpotifyOAuth; +use rspotify::client::ApiError; +use rspotify::oauth2::SpotifyOAuth; use std::collections::HashMap; use std::io::Cursor; use std::sync::{MutexGuard, PoisonError}; @@ -49,7 +49,7 @@ impl<'a> From<std::option::NoneError> for Error { Error::Misc(format!("tried to unwrap none at: {:?}", error)) } } -impl<'a> Responder<'a> for Error { +impl<'a> Responder<'a, 'a> for Error { fn respond_to(self, _: &Request) -> response::Result<'a> { let response = match self { Error::Postgres(e) => format!("DB Error: {:?}", e), @@ -59,7 +59,7 @@ impl<'a> Responder<'a> for Error { Response::build() .header(ContentType::Plain) .status(rocket::http::Status::raw(500)) - .sized_body(Cursor::new(response)) + .sized_body(response.len(), Cursor::new(response)) .ok() } } diff --git a/src/main.rs b/src/main.rs index 9dd4551..eb5af4c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,8 @@ mod errors; mod serve; mod spotify; -fn main() { +#[rocket::main] +async fn main() { database::initialize_db().expect("failed to initialize_db"); println!("connected with db"); rocket::ignite() @@ -20,8 +21,11 @@ fn main() { serve::token, serve::get_tracks, serve::match_users, - serve::get_users + serve::get_users, + serve::create_lobby ], ) - .launch(); + .launch() + .await + .unwrap(); } diff --git a/src/serve.rs b/src/serve.rs index 53a3cac..ee0e5ff 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -3,11 +3,11 @@ use crate::errors::Error; use crate::spotify; use rocket::response::Redirect; -#[get("/callback/<name>/<url>")] -pub fn get_tracks(name: String, url: String) -> Result<(), Error> { - let (spotify_uid, spotify_client) = spotify::auth_user(name.as_ref(), url)?; - let uid = database::insert_user(spotify_uid.as_ref())?; - spotify::load_profile(uid, spotify_uid.as_ref(), spotify_client) +#[get("/callback/<name>/<lobby>/<url>")] +pub async fn get_tracks(name: String, lobby: String, url: String) -> Result<(), Error> { + let (spotify_uid, spotify_client) = spotify::auth_user(name.as_ref(), url).await?; + let uid = database::insert_user(spotify_uid.as_ref(), lobby.as_ref())?; + spotify::load_profile(uid, spotify_uid.as_ref(), spotify_client).await } #[get("/token/<name>")] @@ -15,12 +15,21 @@ pub fn token(name: String) -> Result<Redirect, Error> { Ok(Redirect::to(spotify::token(name)?)) } -#[get("/match/<name1>/<name2>")] -pub fn match_users(name1: String, name2: String) -> Result<String, Error> { - database::match_users(name1, name2) +#[get("/match/<lobby>/<names>")] +pub fn match_users(lobby: String, names: String) -> Result<String, Error> { + //let bytes = base64::decode_config(names, base64::URL_SAFE).unwrap(); + //let names = String::from_utf8(bytes).unwrap(); + let names: Vec<&str> = names.split(',').collect(); + database::match_users(lobby, names.as_slice()) } -#[get("/user")] -pub fn get_users() -> Result<String, Error> { - database::get_users() +#[get("/users/<lobby>")] +pub fn get_users(lobby: String) -> Result<String, Error> { + database::get_users(lobby.as_str()) +} + +//#[post("/lobby", format = "application/json", data = "<name>")] +#[post("/lobby/<name>")] +pub fn create_lobby(name: String) -> Result<String, Error> { + database::create_lobby(name.as_str()) } diff --git a/src/spotify.rs b/src/spotify.rs index 596dec3..0e18c43 100644 --- a/src/spotify.rs +++ b/src/spotify.rs @@ -1,9 +1,13 @@ use crate::database; use crate::errors::Error; use lazy_static::lazy_static; -use rspotify::spotify::client::{ApiError, Spotify}; -use rspotify::spotify::oauth2::{SpotifyClientCredentials, SpotifyOAuth}; -use rspotify::spotify::util::process_token; +use rspotify::client::{ApiError, Spotify}; +use rspotify::model::page::Page; +use rspotify::model::playlist::*; +use rspotify::model::track::*; +use rspotify::oauth2::{SpotifyClientCredentials, SpotifyOAuth}; +use rspotify::senum::TimeRange; +use rspotify::util::process_token; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -19,13 +23,14 @@ macro_rules! get_items { $index = 0; let mut result: Vec<$t> = Vec::new(); loop { - match $spotify_call { - Ok(mut items) => { + let res: Result<Page<$t>, failure::Error> = $spotify_call.await; + match res { + Ok(mut page) => { $index += CHUNK_SIZE; - if items.items.is_empty() { + if page.items.is_empty() { break; } - result.append(&mut items.items); + result.append(&mut page.items); } Err(e) => match e.downcast::<ApiError>() { Ok(ApiError::RateLimited(x)) => { @@ -43,10 +48,7 @@ macro_rules! get_items { }}; } -use rspotify::spotify::model::playlist::*; -use rspotify::spotify::model::track::*; -use rspotify::spotify::senum::TimeRange; -pub fn load_profile(db_uid: i32, spotify_uid: &str, spotify: Spotify) -> Result<(), Error> { +pub async fn load_profile(db_uid: i32, spotify_uid: &str, spotify: Spotify) -> Result<(), Error> { let mut index; let playlists = get_items!( SimplifiedPlaylist, @@ -87,8 +89,8 @@ pub fn load_profile(db_uid: i32, spotify_uid: &str, spotify: Spotify) -> Result< None, ) ); - for track in tracks { - if let Err(e) = database::insert_track(db_uid, track.track, 1) { + for track in tracks.iter().map(|x| x.track.clone()).flatten() { + if let Err(e) = database::insert_track(db_uid, track, 1) { println!("failed to load track to db: {:?}", e) }; } @@ -96,13 +98,16 @@ pub fn load_profile(db_uid: i32, spotify_uid: &str, spotify: Spotify) -> Result< Ok(()) } -pub fn auth_user(name: &str, url: String) -> Result<(String, Spotify), Error> { - let mut guard = CACHE.lock()?; - let mut oauth = guard.remove(name)?; +pub async fn auth_user(name: &str, url: String) -> Result<(String, Spotify), Error> { + let mut oauth = { + let mut guard = CACHE.lock()?; + guard.remove(name)? + }; println!("auth: {:?} url: {}", oauth, url); - let token_info = process_token(&mut oauth, &mut ("?code=".to_owned() + url.as_ref())); + let mut token_string = format!("?code={}", url); + let token_info = process_token(&mut oauth, &mut token_string); let client_credential = SpotifyClientCredentials::default() - .token_info(token_info?) + .token_info(token_info.await?) .build(); let spotify = Spotify::default() @@ -110,14 +115,14 @@ pub fn auth_user(name: &str, url: String) -> Result<(String, Spotify), Error> { .build(); let user_id = spotify .current_user() + .await .map_err(|e| format!("failed to load currentuser {:?}", e))? .id; Ok((user_id, spotify)) } -#[get("/token/<name>")] pub fn token(name: String) -> Result<String, Error> { - let state = rspotify::spotify::util::generate_random_string(16); + let state = rspotify::util::generate_random_string(16); let oauth = SpotifyOAuth::default(); let oauth = oauth .scope("playlist-read-private, playlist-read-collaborative, user-read-private, user-follow-read, user-library-read") |