use crate::db::AppState; use crate::types::JsonMessage; use axum::{ extract::{Query, State}, http::{HeaderMap, StatusCode}, response::IntoResponse, Json, }; use headers::{Cookie, HeaderMapExt}; use redis::AsyncCommands; use serde::{Deserialize, Serialize}; use sqlx::Row; use std::{collections::HashMap, env}; use uuid::Uuid; #[derive(Deserialize)] pub struct CallbackQuery { code: String, state: Option, } #[derive(Deserialize, Serialize)] pub struct DiscordUser { pub id: String, pub username: String, pub discriminator: String, pub avatar: Option, } #[derive(Serialize)] pub struct AuthResponse { user: DiscordUser, session: String, } pub async fn get_user_from_session( headers: &HeaderMap, state: &AppState, ) -> Result { let Some(cookie_header) = headers.typed_get::() else { return Err(( StatusCode::UNAUTHORIZED, Json(JsonMessage { message: "Missing session cookie".into(), }), )); }; let Some(session_id) = cookie_header.get("session") else { return Err(( StatusCode::UNAUTHORIZED, Json(JsonMessage { message: "Missing session ID".into(), }), )); }; let mut redis = state.redis.clone(); let key = format!("session:{}", session_id); let Ok(json) = redis.get::<_, String>(&key).await else { return Err(( StatusCode::UNAUTHORIZED, Json(JsonMessage { message: "Session not found".into(), }), )); }; let Ok(user) = serde_json::from_str::(&json) else { return Err(( StatusCode::UNAUTHORIZED, Json(JsonMessage { message: "Invalid user session".into(), }), )); }; Ok(user) } pub async fn start_oauth(Query(params): Query>) -> impl IntoResponse { let client_id = env::var("CLIENT_ID").unwrap_or_default(); let redirect_uri = env::var("REDIRECT_URI").unwrap_or_default(); let mut url = format!( "https://discord.com/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify", client_id, redirect_uri ); if let Some(redirect) = params.get("redirect") { url.push_str(&format!("&state={}", urlencoding::encode(redirect))); } (StatusCode::FOUND, [(axum::http::header::LOCATION, url)]).into_response() } pub async fn handle_callback( State(state): State, Query(query): Query, ) -> impl IntoResponse { let client_id = env::var("CLIENT_ID").unwrap(); let client_secret = env::var("CLIENT_SECRET").unwrap(); let redirect_uri = env::var("REDIRECT_URI").unwrap(); let form = [ ("client_id", client_id.as_str()), ("client_secret", client_secret.as_str()), ("grant_type", "authorization_code"), ("code", &query.code), ("redirect_uri", redirect_uri.as_str()), ]; let token_res = reqwest::Client::new() .post("https://discord.com/api/oauth2/token") .form(&form) .send() .await; let Ok(res) = token_res else { return ( StatusCode::BAD_REQUEST, Json(JsonMessage { message: "Failed to exchange token".into(), }), ) .into_response(); }; let Ok(token_json) = res.json::().await else { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(JsonMessage { message: "Invalid token response".into(), }), ) .into_response(); }; let Some(access_token) = token_json["access_token"].as_str() else { return ( StatusCode::UNAUTHORIZED, Json(JsonMessage { message: "Access token not found".into(), }), ) .into_response(); }; let user_res = reqwest::Client::new() .get("https://discord.com/api/users/@me") .header("Authorization", format!("Bearer {}", access_token)) .send() .await; let Ok(user_res) = user_res else { return ( StatusCode::BAD_REQUEST, Json(JsonMessage { message: "Failed to fetch user".into(), }), ) .into_response(); }; let Ok(user) = user_res.json::().await else { return ( StatusCode::INTERNAL_SERVER_ERROR, Json(JsonMessage { message: "Failed to parse user".into(), }), ) .into_response(); }; let session_id = Uuid::now_v7().to_string(); let mut redis = state.redis.clone(); let _ = redis .set_ex::<_, _, ()>( format!("session:{}", session_id), serde_json::to_string(&user).unwrap(), 3600, ) .await; let redirect_target = match &query.state { Some(s) => urlencoding::decode(s) .map(|s| s.into_owned()) .unwrap_or("/".to_string()), None => "/".to_string(), }; let mut headers = HeaderMap::new(); headers.insert( "Set-Cookie", format!( "session={}; Max-Age=3600; Path=/; SameSite=None; Secure; HttpOnly", session_id ) .parse() .unwrap(), ); headers.insert( axum::http::header::LOCATION, redirect_target.parse().unwrap(), ); (StatusCode::FOUND, headers).into_response() } pub async fn me(State(state): State, headers: HeaderMap) -> impl IntoResponse { match get_user_from_session(&headers, &state).await { Ok(user) => { let result = sqlx::query("SELECT timezone FROM timezones WHERE user_id = $1") .bind(&user.id) .fetch_optional(&state.db) .await; match result { Ok(Some(row)) => { let timezone: String = row.get("timezone"); ( StatusCode::OK, Json(serde_json::json!({ "user": user, "timezone": timezone })), ) .into_response() } Ok(None) => ( StatusCode::OK, Json(serde_json::json!({ "user": user, "timezone": null })), ) .into_response(), Err(_) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(JsonMessage { message: "Failed to fetch timezone".into(), }), ) .into_response(), } } Err(err) => err.into_response(), } }