diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/database.rs | 164 | ||||
-rw-r--r-- | src/errors.rs | 20 | ||||
-rw-r--r-- | src/main.rs | 47 | ||||
-rw-r--r-- | src/serve.rs | 16 | ||||
-rw-r--r-- | src/spotify.rs | 95 |
5 files changed, 208 insertions, 134 deletions
diff --git a/src/database.rs b/src/database.rs index b2cd323..ecb88d0 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,22 +1,19 @@ use crate::errors::Error; -use postgres::{Client, NoTls}; use rand::prelude::*; -use std::sync::{Arc, Mutex}; +use std::{mem::zeroed, sync::Arc}; +use tokio::sync::Mutex; +use tokio_postgres::{Client, NoTls}; +pub static mut _client: Option<Client> = None; lazy_static! { - static ref CLIENT: Arc<Mutex<Client>> = Arc::new(Mutex::new( - Client::connect( - "host=127.0.0.1 user=spotify_intersect password=example dbname=track_db", - NoTls - ) - .expect("failed to connect to database") - )); + static ref CLIENT: Client = unsafe { _client.take().unwrap() }; } -pub fn initialize_db() -> Result<(), Error> { - let mut client = CLIENT.lock()?; - client.batch_execute( - r#" +pub async fn initialize_db() -> Result<(), Error> { + let client = &CLIENT; + client + .batch_execute( + r#" CREATE TABLE IF NOT EXISTS track ( track_id SERIAL PRIMARY KEY, track_code TEXT NOT NULL UNIQUE, @@ -51,92 +48,105 @@ pub fn initialize_db() -> Result<(), Error> { FROM user_track_raw ); "#, - )?; + ) + .await?; Ok(()) } use rspotify::model::track::FullTrack; -pub fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), Error> { - let mut client = CLIENT.lock()?; +pub async fn insert_track(user_id: i32, track: FullTrack, weight: i32) -> Result<(), Error> { + let client = &CLIENT; if track.id.is_none() { println!("{:#?}", track); return Err("failed to load get track information".into()); } print!(" {} ", track.id.clone()?); - client.execute( - "INSERT INTO track (track_code, name, artist, popularity) + client + .execute( + "INSERT INTO track (track_code, name, artist, popularity) VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", - &[ - &(track.id.clone()?), - &track.name, - &track.artists[0].name, - &(track.popularity as i32), - ], - )?; - let track_id: i32 = client.query( - "SELECT track_id FROM track where track_code = $1;", - &[&(track.id?)], - )?[0] + &[ + &(track.id.clone()?), + &track.name, + &track.artists[0].name, + &(track.popularity as i32), + ], + ) + .await?; + let track_id: i32 = client + .query( + "SELECT track_id FROM track where track_code = $1;", + &[&(track.id?)], + ) + .await?[0] .get(0); println!("uid: {} tid: {}", user_id, track_id); - client.execute( - " + client + .execute( + " INSERT INTO user_track_raw (track_id, user_id, count) VALUES ($1, $2, $3) ON CONFLICT ON CONSTRAINT track_user_pkey DO NOTHING; ", - &[&track_id, &user_id, &0], - )?; - client.execute( - "UPDATE user_track SET count = count + $3 WHERE track_id = $1 AND user_id = $2;", - &[&track_id, &user_id, &weight], - )?; + &[&track_id, &user_id, &0], + ) + .await?; + client + .execute( + "UPDATE user_track SET count = count + $3 WHERE track_id = $1 AND user_id = $2;", + &[&track_id, &user_id, &weight], + ) + .await?; Ok(()) } -pub fn insert_user(name: &str, lobby: &str) -> Result<i32, Error> { - let mut client = CLIENT.lock()?; - let x = get_lid(lobby, &mut *client)?; +pub async fn insert_user(name: &str, lobby: &str) -> Result<i32, Error> { + let client = &CLIENT; + let x = get_lid(lobby, &*client).await?; client.execute( "INSERT INTO suser (user_name, lobby_id) VALUES ($1, $2) ON CONFLICT (user_name, lobby_id) DO NOTHING;", &[&name, &x], - )?; - let db_user_id = get_uid(name, lobby, &mut *client)?; - client.execute("DELETE FROM user_track WHERE user_id = $1;", &[&db_user_id])?; + ).await?; + let db_user_id = get_uid(name, lobby, &*client).await?; + client + .execute("DELETE FROM user_track WHERE user_id = $1;", &[&db_user_id]) + .await?; Ok(db_user_id) } -fn get_lid(lobby: &str, client: &mut postgres::Client) -> Result<i32, Error> { +async fn get_lid(lobby: &str, client: &Client) -> Result<i32, Error> { let x: i32 = client - .query_one("SELECT lobby_id FROM lobby WHERE token = $1;", &[&lobby])? + .query_one("SELECT lobby_id FROM lobby WHERE token = $1;", &[&lobby]) + .await? .get(0); Ok(x) } -fn get_uid(name: &str, lobby: &str, client: &mut postgres::Client) -> Result<i32, Error> { +async fn get_uid(name: &str, lobby: &str, client: &Client) -> Result<i32, Error> { let x: i32 = client - .query_one("SELECT user_id FROM suser JOIN lobby USING (lobby_id) WHERE user_name = $1 AND token = $2;", &[&name, &lobby])? + .query_one("SELECT user_id FROM suser JOIN lobby USING (lobby_id) WHERE user_name = $1 AND token = $2;", &[&name, &lobby]).await? .get(0); Ok(x) } -pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { - let mut client = CLIENT.lock()?; +pub async fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { + let client = &CLIENT; let mut songs = String::new(); let names: Vec<String> = names.iter().map(|x| x.to_string()).collect(); - let users = names - .iter() - .fold(String::new(), |a, name| format!("{}, ({})", a, name)); - let users: String = users.chars().skip(2).collect(); - for row in client.query( + //let users: Vec<&str> = names.iter().map(|x| x.as_str()).collect(); + //println!("users to match: {:?}", users); + let placeholders = + (2..=(names.len() + 1)).fold(String::new(), |p, a| format!("{}, (${})", p, a)); + let placeholders: String = placeholders.chars().skip(2).collect(); + let query = format!( " WITH users AS ( SELECT * - FROM ( VALUES $1 AS _ (user_id) + FROM ( VALUES {} ) AS _ (user_id) ) SELECT track_id, name, artist FROM track @@ -146,15 +156,23 @@ pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { JOIN suser USING (user_id) JOIN track USING (track_id) JOIN lobby USING (lobby_id) - WHERE suser.user_name IN (SELECT * FROM users) AND track = $2 + WHERE suser.user_name IN (SELECT * FROM users) AND token = $1 GROUP BY track_id HAVING COUNT(track_id) = (SELECT COUNT(*) FROM users) ORDER BY SUM(score) DESC ) AS _ USING (track_id) ; ", - &[&users, &lobby], - )? { + placeholders + ); + let mut values = vec![lobby]; + values.extend(names); + println!("{:?}", values); + + use futures::{pin_mut, TryStreamExt}; + let mut it = client.query_raw(query.as_str(), values).await?; + pin_mut!(it); + for row in it.try_next().await? { let name: String = row.get(1); let artist: String = row.get(2); songs = format!("{}{} by {}\n", songs, name, artist); @@ -162,33 +180,39 @@ pub fn match_users(lobby: String, names: &[&str]) -> Result<String, Error> { Ok(songs) } -pub fn get_users(lobby: &str) -> Result<String, Error> { - let mut client = CLIENT.lock()?; +pub async fn get_users(lobby: &str) -> Result<String, Error> { + let client = &CLIENT; let mut users = String::new(); - for row in client.query( - "SELECT user_name FROM suser JOIN lobby USING (lobby_id) WHERE token = $1", - &[&lobby], - )? { + for row in client + .query( + "SELECT user_name FROM suser JOIN lobby USING (lobby_id) WHERE token = $1", + &[&lobby], + ) + .await? + { let user: String = row.get(0); users = format!("{}{}\n", users, user); } Ok(users) } -pub fn create_lobby(name: &str) -> Result<String, Error> { - let mut client = CLIENT.lock()?; +pub async fn create_lobby(name: &str) -> Result<String, Error> { + let client = &CLIENT; let mut token = String::new(); while token.is_empty() || client .query_one("SELECT lobby_id FROM lobby WHERE token = $1", &[&token]) + .await .is_ok() { let rand: [u8; 20] = rand::thread_rng().gen(); token = base64::encode_config(&rand, base64::URL_SAFE); } - client.execute( - "INSERT INTO lobby (token, lobby_name) VALUES ($1, $2);", - &[&token, &name], - )?; + client + .execute( + "INSERT INTO lobby (token, lobby_name) VALUES ($1, $2);", + &[&token, &name], + ) + .await?; Ok(token) } diff --git a/src/errors.rs b/src/errors.rs index 544c391..d935c49 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -1,17 +1,17 @@ use rocket::http::ContentType; use rocket::request::Request; use rocket::response::{self, Responder, Response}; -use rspotify::client::ApiError; -use rspotify::oauth2::SpotifyOAuth; +use rspotify::client::ClientError; +use rspotify::client::Spotify; use std::collections::HashMap; use std::io::Cursor; use std::sync::{MutexGuard, PoisonError}; -use tokio_postgres::error::Error as DbError; +use tokio_postgres::Error as DbError; #[derive(Debug)] pub enum Error { Postgres(DbError), - Spotify(ApiError), + Spotify(ClientError), Misc(String), } impl From<DbError> for Error { @@ -19,8 +19,8 @@ impl From<DbError> for Error { Error::Postgres(error) } } -impl From<ApiError> for Error { - fn from(error: ApiError) -> Self { +impl From<ClientError> for Error { + fn from(error: ClientError) -> Self { Error::Spotify(error) } } @@ -34,13 +34,13 @@ impl From<String> for Error { Error::Misc(error) } } -impl<'a> From<PoisonError<MutexGuard<'a, HashMap<String, SpotifyOAuth>>>> for Error { - fn from(error: PoisonError<MutexGuard<'a, HashMap<String, SpotifyOAuth>>>) -> Self { +impl<'a> From<PoisonError<MutexGuard<'a, HashMap<String, Spotify>>>> for Error { + fn from(error: PoisonError<MutexGuard<'a, HashMap<String, Spotify>>>) -> Self { Error::Misc(format!("failed to lock the client mutex: {:?}", error)) } } -impl<'a> From<PoisonError<MutexGuard<'a, postgres::Client>>> for Error { - fn from(error: PoisonError<MutexGuard<'a, postgres::Client>>) -> Self { +impl<'a> From<PoisonError<MutexGuard<'a, tokio_postgres::Client>>> for Error { + fn from(error: PoisonError<MutexGuard<'a, tokio_postgres::Client>>) -> Self { Error::Misc(format!("failed to lock the client mutex: {:?}", error)) } } diff --git a/src/main.rs b/src/main.rs index eb5af4c..0ff4827 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,22 +10,39 @@ mod errors; mod serve; mod spotify; -#[rocket::main] -async fn main() { - database::initialize_db().expect("failed to initialize_db"); +use tokio::runtime::Runtime; + +fn main() { + let runtime = Runtime::new().unwrap(); + let (client, connection) = runtime + .block_on(tokio_postgres::connect( + "host=127.0.0.1 user=spotify_intersect password=example dbname=track_db", + tokio_postgres::NoTls, + )) + .expect("failed to connect to database"); + unsafe { + database::_client = Some(client); + } + runtime.spawn(async { connection.await.unwrap() }); + + runtime + .block_on(database::initialize_db()) + .expect("failed to initialize_db"); println!("connected with db"); - rocket::ignite() - .mount( - "/", - routes![ - serve::token, - serve::get_tracks, - serve::match_users, - serve::get_users, - serve::create_lobby - ], + runtime + .block_on( + rocket::ignite() + .mount( + "/", + routes![ + serve::token, + serve::get_tracks, + serve::match_users, + serve::get_users, + serve::create_lobby + ], + ) + .launch(), ) - .launch() - .await .unwrap(); } diff --git a/src/serve.rs b/src/serve.rs index ee0e5ff..ab07607 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -5,8 +5,8 @@ use rocket::response::Redirect; #[get("/callback/<name>/<lobby>/<url>")] pub async fn get_tracks(name: String, lobby: String, url: String) -> Result<(), Error> { - let (spotify_uid, spotify_client) = spotify::auth_user(name.as_ref(), url).await?; - let uid = database::insert_user(spotify_uid.as_ref(), lobby.as_ref())?; + let (spotify_uid, spotify_client) = spotify::auth_user(name.as_ref(), url.as_str()).await?; + let uid = database::insert_user(spotify_uid.as_ref(), lobby.as_ref()).await?; spotify::load_profile(uid, spotify_uid.as_ref(), spotify_client).await } @@ -16,20 +16,20 @@ pub fn token(name: String) -> Result<Redirect, Error> { } #[get("/match/<lobby>/<names>")] -pub fn match_users(lobby: String, names: String) -> Result<String, Error> { +pub async fn match_users(lobby: String, names: String) -> Result<String, Error> { //let bytes = base64::decode_config(names, base64::URL_SAFE).unwrap(); //let names = String::from_utf8(bytes).unwrap(); let names: Vec<&str> = names.split(',').collect(); - database::match_users(lobby, names.as_slice()) + database::match_users(lobby, names.as_slice()).await } #[get("/users/<lobby>")] -pub fn get_users(lobby: String) -> Result<String, Error> { - database::get_users(lobby.as_str()) +pub async fn get_users(lobby: String) -> Result<String, Error> { + database::get_users(lobby.as_str()).await } //#[post("/lobby", format = "application/json", data = "<name>")] #[post("/lobby/<name>")] -pub fn create_lobby(name: String) -> Result<String, Error> { - database::create_lobby(name.as_str()) +pub async fn create_lobby(name: String) -> Result<String, Error> { + database::create_lobby(name.as_str()).await } diff --git a/src/spotify.rs b/src/spotify.rs index 0e18c43..ea05b14 100644 --- a/src/spotify.rs +++ b/src/spotify.rs @@ -1,19 +1,16 @@ use crate::database; use crate::errors::Error; use lazy_static::lazy_static; -use rspotify::client::{ApiError, Spotify}; -use rspotify::model::page::Page; +use rand::{Rng, RngCore}; +use rspotify::client::{ClientError, Spotify}; +use rspotify::model::enums::TimeRange; use rspotify::model::playlist::*; use rspotify::model::track::*; -use rspotify::oauth2::{SpotifyClientCredentials, SpotifyOAuth}; -use rspotify::senum::TimeRange; -use rspotify::util::process_token; use std::collections::HashMap; use std::sync::{Arc, Mutex}; lazy_static! { - static ref CACHE: Arc<Mutex<HashMap<String, SpotifyOAuth>>> = - Arc::new(Mutex::new(HashMap::new())); + static ref CACHE: Arc<Mutex<HashMap<String, Spotify>>> = Arc::new(Mutex::new(HashMap::new())); } static CHUNK_SIZE: u32 = 50; @@ -23,8 +20,7 @@ macro_rules! get_items { $index = 0; let mut result: Vec<$t> = Vec::new(); loop { - let res: Result<Page<$t>, failure::Error> = $spotify_call.await; - match res { + match $spotify_call.await { Ok(mut page) => { $index += CHUNK_SIZE; if page.items.is_empty() { @@ -32,8 +28,8 @@ macro_rules! get_items { } result.append(&mut page.items); } - Err(e) => match e.downcast::<ApiError>() { - Ok(ApiError::RateLimited(x)) => { + Err(e) => match e { + ClientError::RateLimited(x) => { std::thread::sleep(std::time::Duration::from_secs(x.unwrap_or(5) as u64)) } @@ -66,22 +62,23 @@ pub async fn load_profile(db_uid: i32, spotify_uid: &str, spotify: Spotify) -> R spotify.current_user_top_tracks(CHUNK_SIZE, index, TimeRange::MediumTerm) ); for track in library { - if let Err(e) = database::insert_track(db_uid, track.track, 5) { + if let Err(e) = database::insert_track(db_uid, track.track, 5).await { println!("failed to load track to db: {:?}", e) }; } for (pos, track) in top_tracks.iter().enumerate() { let weight = ((50.0 - pos as f64) / 50.0 * 10.0).floor() as i32; - if let Err(e) = database::insert_track(db_uid, track.clone(), 5 + weight) { + if let Err(e) = database::insert_track(db_uid, track.clone(), 5 + weight).await { println!("failed to load track to db: {:?}", e) }; } + let playlists = playlists.iter().map(|x| x.clone()); for playlist in playlists { let tracks = get_items!( - PlaylistTrack, + PlaylistItem, index, - spotify.user_playlist_tracks( - spotify_uid.as_ref(), + spotify.playlist_tracks( + //spotify_uid.as_ref(), &playlist.id, None, CHUNK_SIZE, @@ -89,8 +86,9 @@ pub async fn load_profile(db_uid: i32, spotify_uid: &str, spotify: Spotify) -> R None, ) ); - for track in tracks.iter().map(|x| x.track.clone()).flatten() { - if let Err(e) = database::insert_track(db_uid, track, 1) { + let tracks: Vec<FullTrack> = tracks.iter().map(|x| x.track.clone()).flatten().collect(); + for track in tracks { + if let Err(e) = database::insert_track(db_uid, track, 1).await { println!("failed to load track to db: {:?}", e) }; } @@ -98,21 +96,37 @@ pub async fn load_profile(db_uid: i32, spotify_uid: &str, spotify: Spotify) -> R Ok(()) } -pub async fn auth_user(name: &str, url: String) -> Result<(String, Spotify), Error> { - let mut oauth = { - let mut guard = CACHE.lock()?; +use rspotify::client::SpotifyBuilder; +use rspotify::oauth2::{CredentialsBuilder, OAuthBuilder}; + +/// Generate `length` random chars +fn generate_random_uuid(length: usize) -> String { + let alphanum: &[u8] = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".as_bytes(); + let mut buf = vec![0u8; length]; + rand::thread_rng().fill_bytes(buf.as_mut_slice()); + let range = alphanum.len(); + + buf.iter() + .map(|byte| alphanum[*byte as usize % range] as char) + .collect() +} + +pub async fn auth_user(name: &str, code: &str) -> Result<(String, Spotify), Error> { + let mut spotify = { + let mut guard = (*CACHE).lock()?; guard.remove(name)? }; - println!("auth: {:?} url: {}", oauth, url); - let mut token_string = format!("?code={}", url); - let token_info = process_token(&mut oauth, &mut token_string); - let client_credential = SpotifyClientCredentials::default() + println!("auth: {:?} url: {}", name, code); + spotify.request_user_token(code).await?; + //let token_info = process_token(&mut oauth, &mut token_string); + /*let client_credential = SpotifyClientCredentials::default() .token_info(token_info.await?) .build(); let spotify = Spotify::default() - .client_credentials_manager(client_credential) - .build(); + .client_credentials_manager(client_credential) + .build();*/ let user_id = spotify .current_user() .await @@ -122,13 +136,32 @@ pub async fn auth_user(name: &str, url: String) -> Result<(String, Spotify), Err } pub fn token(name: String) -> Result<String, Error> { - let state = rspotify::util::generate_random_string(16); + let scope = "playlist-read-private playlist-read-collaborative user-read-private user-follow-read user-library-read"; + + let oauth = OAuthBuilder::from_env() + .scope(scope.split_whitespace().map(|x| x.to_owned()).collect()) + .build() + .unwrap(); + let creds = CredentialsBuilder::from_env().build().unwrap(); + + let spotify = SpotifyBuilder::default() + .credentials(creds) + .oauth(oauth) + .build() + .unwrap(); + + let auth_url = spotify.get_authorize_url(false).unwrap(); + + //let token = spotify.token.as_ref().unwrap(); + + /*let state = rspotify::util::generate_random_string(16); let oauth = SpotifyOAuth::default(); let oauth = oauth .scope("playlist-read-private, playlist-read-collaborative, user-read-private, user-follow-read, user-library-read") .build(); - let auth_url = oauth.get_authorize_url(Some(&state), None); - let mut guard = CACHE.lock()?; - guard.insert(name, oauth); + */ + //let auth_url = oauth.get_authorize_url(Some(&state), None); + let mut guard = (*CACHE).lock()?; + guard.insert(name, spotify); Ok(auth_url) } |