diff --git a/src/database/bot.rs b/src/database/bot.rs index c1d0c8f..1ae6615 100644 --- a/src/database/bot.rs +++ b/src/database/bot.rs @@ -1,5 +1,5 @@ use sqlx::{PgPool, query, query_as}; -use crate::models::bot::{Bot, BotPresence, BotActivity}; +use crate::models::bot::{DbBot, BotPresence, BotActivity}; const BOT_ID: i32 = 1; @@ -11,8 +11,8 @@ pub async fn init(db: &PgPool) -> Result<(), sqlx::Error> { Ok(()) } -pub async fn get(db: &PgPool) -> Result, sqlx::Error> { - let bot = query_as::<_, Bot>( +pub async fn get(db: &PgPool) -> Result, sqlx::Error> { + let bot = query_as::<_, DbBot>( "SELECT * FROM bots WHERE id = $1", ) .bind(BOT_ID) diff --git a/src/database/guild.rs b/src/database/guild.rs index 2c5f8fa..9e2012c 100644 --- a/src/database/guild.rs +++ b/src/database/guild.rs @@ -1,9 +1,10 @@ use sqlx::{ PgPool, + query, query_as, query_scalar, }; -use crate::models::Guild; +use crate::models::DbGuild; pub enum LogChannel { Bot, @@ -63,8 +64,8 @@ pub async fn delete(db: &PgPool, guild_id: &str) -> Result<(), sqlx::Error> { Ok(()) } -pub async fn get(db: &PgPool, guild_id: &str) -> Result, sqlx::Error> { - let guild: Option = query_as::<_, Guild>( +pub async fn get(db: &PgPool, guild_id: &str) -> Result, sqlx::Error> { + let guild: Option = query_as::<_, DbGuild>( "SELECT * FROM guilds WHERE guild_id = $1", ) .bind(guild_id) @@ -110,3 +111,10 @@ pub async fn set_protect(db: &PgPool, user_id: &str, asked: Protect, value: &str .await?; Ok(()) } + +pub async fn get_or_create(db: &PgPool, guild_id: &str) -> Result { + create(db, guild_id).await?; + get(db, guild_id) + .await? + .ok_or_else(|| sqlx::Error::RowNotFound) +} diff --git a/src/database/guild_user.rs b/src/database/guild_user.rs index 49164c2..0a7e799 100644 --- a/src/database/guild_user.rs +++ b/src/database/guild_user.rs @@ -1,12 +1,12 @@ use sqlx::{PgPool, query, query_as, query_scalar}; -use crate::models::GuildUser; +use crate::models::guild_user::DbGuildUser; pub async fn get( db: &PgPool, user_id: &str, guild_id: &str, -) -> Result, sqlx::Error> { - let guild_user = query_as::<_, guild_userildUser>( +) -> Result, sqlx::Error> { + let guild_user = query_as::<_, DbGuildUser>( "SELECT * FROM guild_users WHERE user_id = $1 AND guild_id = $2", ) .bind(user_id) @@ -146,8 +146,8 @@ pub async fn set_wl( pub async fn get_all_wl( db: &PgPool, guild_id: &str, -) -> Result, sqlx::Error> { - let users = query_as::<_, GuildUser>( +) -> Result, sqlx::Error> { + let users = query_as::<_, DbGuildUser>( "SELECT * FROM guild_users \ WHERE guild_id = $1 AND is_wl_user = true", ) @@ -211,8 +211,8 @@ pub async fn leaderboard_xp( db: &PgPool, guild_id: &str, limit: i64, -) -> Result, sqlx::Error> { - let users = query_as::<_, GuildUser>( +) -> Result, sqlx::Error> { + let users = query_as::<_, DbGuildUser>( "SELECT * FROM guild_users \ WHERE guild_id = $1 \ ORDER BY xp DESC \ @@ -229,8 +229,8 @@ pub async fn leaderboard_invitations( db: &PgPool, guild_id: &str, limit: i64, -) -> Result, sqlx::Error> { - let users = query_as::<_, GuildUser>( +) -> Result, sqlx::Error> { + let users = query_as::<_, DbGuildUser>( "SELECT * FROM guild_users \ WHERE guild_id = $1 \ ORDER BY invitation_count DESC \ @@ -242,4 +242,9 @@ pub async fn leaderboard_invitations( .await?; Ok(users) } - +pub async fn get_or_create(db: &PgPool, user_id: &str, guild_id: &str) -> Result { + create(db, user_id, guild_id).await?; + get(db, user_id, guild_id) + .await? + .ok_or_else(|| sqlx::Error::RowNotFound) +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 6706249..cfb1275 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1 +1,10 @@ include!("./mod_gen.rs"); + +use sqlx::{Pool, Postgres}; +use serenity::prelude::TypeMapKey; + +pub struct DbPool; + +impl TypeMapKey for DbPool { + type Value = Pool; +} diff --git a/src/database/user.rs b/src/database/user.rs index 009d4b4..5986878 100644 --- a/src/database/user.rs +++ b/src/database/user.rs @@ -3,7 +3,7 @@ use sqlx::{ query, query_as }; -use crate::models::User; +use crate::models::DbUser; /// Adding the user (if exist do nothing) /// @@ -26,7 +26,7 @@ pub async fn create(db: &PgPool, user_id: &str) -> Result<(), sqlx::Error> { /// Take the database information of a user /// /// # Returns -/// [`User`] or `None` if the user doesn't exist +/// [`DbUser`] or `None` if the user doesn't exist /// /// # Arguments /// @@ -36,9 +36,9 @@ pub async fn create(db: &PgPool, user_id: &str) -> Result<(), sqlx::Error> { /// # Errors /// /// Returns `sqlx::Error` if the query fails. -pub async fn get(db: &PgPool, user_id: &str) -> Result, sqlx::Error> { - let user: Option = query_as::<_, User>( - "SELECT user_id, is_owner, is_buyer, is_dev FROM users WHERE user_id = $1", +pub async fn get(db: &PgPool, user_id: &str) -> Result, sqlx::Error> { + let user: Option = query_as::<_, DbUser>( + "SELECT * FROM users WHERE user_id = $1", ) .bind(user_id) .fetch_optional(db) @@ -85,3 +85,10 @@ pub async fn set_buyer(db: &PgPool, user_id: &str, value: bool) -> Result<(), sq .await?; Ok(()) } + +pub async fn get_or_create(db: &PgPool, user_id: &str) -> Result { + create(db, user_id).await?; + get(db, user_id) + .await? + .ok_or_else(|| sqlx::Error::RowNotFound) +}