# Git Commit Message

refactor: add production features and improve architecture

- Add structured configuration with validation
- Implement Redis connection pooling
- Add database migrations system
- Change API methods: GET /set → POST /set, GET /delete → DELETE /delete
- Add health check endpoint
- Add graceful shutdown and structured logging
- Update frontend for new API methods
- Add open source attribution
This commit is contained in:
creations 2025-06-04 07:56:15 -04:00
parent ad6c9b7095
commit 6bfd298455
Signed by: creations
GPG key ID: 8F553AA4320FC711
21 changed files with 842 additions and 3124 deletions

View file

@ -1,2 +1,8 @@
postgres-data
dragonfly-data
target/
postgres-data/
dragonfly-data/
.env
.git/
.gitignore
README.md
*.log

21
.env.example Normal file
View file

@ -0,0 +1,21 @@
# Server Configuration
HOST=0.0.0.0
PORT=3000
# Database Configuration
DATABASE_URL=postgres://postgres:postgres@postgres:5432/postgres
DB_MAX_CONNECTIONS=10
DB_CONNECT_TIMEOUT=30
# Redis Configuration
REDIS_URL=redis://dragonfly:6379
REDIS_POOL_SIZE=5
REDIS_CONNECT_TIMEOUT=10
# Discord OAuth Configuration
CLIENT_ID=your_discord_client_id
CLIENT_SECRET=your_discord_client_secret
REDIRECT_URI=https://your.domain/auth/discord/callback
# Logging (optional)
RUST_LOG=info,timezone_db=debug

1
.gitignore vendored
View file

@ -3,3 +3,4 @@
dragonfly-data
postgres-data
.env
Cargo.lock

3019
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -5,18 +5,25 @@ edition = "2021"
[dependencies]
axum = "0.8.4"
tokio = { version = "1", features = ["full"] }
sqlx = { version = "0.8.6", features = ["postgres", "runtime-tokio", "macros"] }
tokio = { version = "1", features = ["full", "signal"] }
sqlx = { version = "0.8.6", features = [
"postgres",
"runtime-tokio",
"macros",
"chrono",
] }
redis = { version = "0.31", features = ["tokio-comp", "aio"] }
uuid = { version = "1", features = ["v7"] }
dotenvy = "0.15"
tracing = "0.1"
tracing-subscriber = "0.3"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1.0"
reqwest = { version = "0.12", features = ["json", "gzip"] }
tower-http = { version = "0.6.4", features = ["cors", "fs"] }
headers = "0.4.0"
chrono-tz = "0.10.3"
chrono = { version = "0.4", features = ["serde"] }
tower = "0.5.2"
urlencoding = "2.1.3"
thiserror = "2.0.12"

View file

@ -5,6 +5,7 @@ WORKDIR /app
COPY Cargo.toml Cargo.lock ./
COPY src ./src
COPY public ./public
COPY migrations ./migrations
RUN cargo build --release
@ -16,7 +17,6 @@ WORKDIR /app
COPY --from=builder /app/target/release/timezone-db /usr/local/bin/app
COPY --from=builder /app/public ./public
ENV RUST_LOG=info
COPY --from=builder /app/migrations ./migrations
CMD ["/usr/local/bin/app"]

View file

@ -7,8 +7,10 @@ A simple Rust-powered API service for managing and retrieving user timezones.
- Store user timezones via `/set` endpoint (requires Discord OAuth)
- Retrieve timezones by user ID via `/get`
- List all saved timezones
- Cookie-based session handling using Redis
- Cookie-based session handling using Redis connection pooling
- Built-in CORS support
- Structured configuration with validation
- Graceful shutdown support
- Fully containerized with PostgreSQL and DragonflyDB
## Requirements
@ -21,15 +23,27 @@ A simple Rust-powered API service for managing and retrieving user timezones.
Create a `.env` file with the following:
```env
# Server Configuration
HOST=0.0.0.0
PORT=3000
# Database Configuration
DATABASE_URL=postgres://postgres:postgres@postgres:5432/postgres
REDIS_URL=redis://dragonfly:6379
DB_MAX_CONNECTIONS=10
DB_CONNECT_TIMEOUT=30
# Redis Configuration
REDIS_URL=redis://dragonfly:6379
REDIS_POOL_SIZE=5
REDIS_CONNECT_TIMEOUT=10
# Discord OAuth Configuration
CLIENT_ID=your_discord_client_id
CLIENT_SECRET=your_discord_client_secret
REDIRECT_URI=https://your.domain/auth/discord/callback
# Logging (optional)
RUST_LOG=info,timezone_db=debug
```
## Setup
@ -40,16 +54,34 @@ REDIRECT_URI=https://your.domain/auth/discord/callback
docker compose up --build
```
### Run Manually
```bash
# Make sure PostgreSQL and Redis are running
cargo run
```
## API Endpoints
### `GET /get?id=<discord_user_id>`
Returns stored timezone and username for the given user ID.
**Response:**
```json
{
"user": {
"id": "123456789",
"username": "username"
},
"timezone": "America/New_York"
}
```
### `POST /set`
Stores timezone for the authenticated user. Requires Discord OAuth session.
Body: `application/x-www-form-urlencoded` with `timezone=<iana_timezone>`
**Body:** `application/x-www-form-urlencoded` with `timezone=<iana_timezone>`
### `DELETE /delete`
@ -59,14 +91,40 @@ Deletes the authenticated user's timezone entry. Requires Discord OAuth session.
Returns a JSON object of all stored timezones by user ID.
**Response:**
```json
{
"123456789": {
"username": "user1",
"timezone": "America/New_York"
},
"987654321": {
"username": "user2",
"timezone": "Europe/London"
}
}
```
### `GET /me`
Returns Discord profile info for the current session.
Returns Discord profile info and timezone for the current session.
### `GET /auth/discord`
Starts OAuth2 authentication flow.
Starts OAuth2 authentication flow. Supports optional `?redirect=` parameter.
### `GET /auth/discord/callback`
Handles OAuth2 redirect and sets a session cookie.
## Configuration
The application uses structured configuration with validation. All required environment variables must be provided, and the app will exit with helpful error messages if configuration is invalid.
### Optional Configuration Variables
- `DB_MAX_CONNECTIONS`: Maximum database connections (default: 10)
- `DB_CONNECT_TIMEOUT`: Database connection timeout in seconds (default: 30)
- `REDIS_POOL_SIZE`: Redis connection pool size (default: 5)
- `REDIS_CONNECT_TIMEOUT`: Redis connection timeout in seconds (default: 10)
- `RUST_LOG`: Logging level configuration

View file

@ -15,6 +15,12 @@ services:
condition: service_started
networks:
- timezoneDB-network
healthcheck:
test: ["CMD", "curl", "-f", "http://${HOST:-localhost}:${PORT:-3000}/health"]
interval: 30s
timeout: 3s
retries: 3
start_period: 10s
postgres:
image: postgres:16
@ -23,6 +29,7 @@ services:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
PGUSER: postgres
volumes:
- ./postgres-data:/var/lib/postgresql/data
networks:
@ -45,4 +52,4 @@ services:
networks:
timezoneDB-network:
driver: bridge
driver: bridge

View file

@ -0,0 +1,9 @@
-- Initial schema: one row per Discord user with their chosen IANA timezone.
-- Idempotent (IF NOT EXISTS) so re-running the migration is harmless.
CREATE TABLE IF NOT EXISTS timezones (
    user_id TEXT PRIMARY KEY,
    username TEXT NOT NULL,
    timezone TEXT NOT NULL,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW()
);

-- Supports queries that filter or order by last-modified time.
CREATE INDEX IF NOT EXISTS idx_timezones_updated_at ON timezones(updated_at);

View file

@ -0,0 +1,5 @@
-- Backfill for databases created before the timestamp columns existed;
-- idempotent thanks to IF NOT EXISTS, so fresh installs are unaffected.
ALTER TABLE timezones
    ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT NOW();

ALTER TABLE timezones
    ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ DEFAULT NOW();

View file

@ -7,18 +7,19 @@ const setBtn = document.getElementById("set-timezone");
const statusMsg = document.getElementById("status-msg");
const timezones = Intl.supportedValuesOf("timeZone");
timezones.forEach(tz => {
for (const tz of timezones) {
const opt = document.createElement("option");
opt.value = tz;
opt.textContent = tz;
timezoneSelect.appendChild(opt);
});
}
const ts = new TomSelect("#timezone-select", {
create: false,
sorted: true,
searchField: ["text"],
maxOptions: 1000
maxOptions: 1000,
});
async function fetchUserInfo() {
@ -54,7 +55,7 @@ async function fetchUserInfo() {
try {
const res = await fetch("/delete", {
method: "DELETE",
credentials: "include"
credentials: "include",
});
if (!res.ok) throw new Error();
@ -66,7 +67,6 @@ async function fetchUserInfo() {
statusMsg.textContent = "Failed to delete timezone.";
}
});
} catch {
loginSection.classList.remove("hidden");
timezoneSection.classList.add("hidden");
@ -77,17 +77,30 @@ setBtn.addEventListener("click", async () => {
const timezone = ts.getValue();
if (!timezone) return;
setBtn.disabled = true;
setBtn.textContent = "Saving...";
statusMsg.textContent = "";
try {
const res = await fetch("/set", {
method: "POST",
credentials: "include",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: `timezone=${encodeURIComponent(timezone)}`
body: `timezone=${encodeURIComponent(timezone)}`,
});
if (!res.ok) throw new Error();
if (!res.ok) {
const error = await res.json();
throw new Error(error.message || "Failed to update timezone");
}
statusMsg.textContent = "Timezone updated!";
} catch {
statusMsg.textContent = "Failed to update timezone.";
document.getElementById("delete-timezone").classList.remove("hidden");
} catch (error) {
statusMsg.textContent = error.message;
} finally {
setBtn.disabled = false;
setBtn.textContent = "Save";
}
});

9
rustfmt.toml Normal file
View file

@ -0,0 +1,9 @@
# Project rustfmt configuration — these largely restate the defaults so the
# intended style survives future rustfmt default changes.
max_width = 100
hard_tabs = false
tab_spaces = 4
newline_style = "Unix"
use_small_heuristics = "Default"
reorder_imports = true
reorder_modules = true
remove_nested_parens = true
edition = "2021"

197
src/config.rs Normal file
View file

@ -0,0 +1,197 @@
use std::env;
use std::net::{IpAddr, SocketAddr};
/// Top-level application configuration, assembled from environment variables
/// by `Config::from_env` and semantically checked by `Config::validate`.
#[derive(Debug, Clone)]
pub struct Config {
    pub server: ServerConfig,
    pub database: DatabaseConfig,
    pub redis: RedisConfig,
    pub discord: DiscordConfig,
}

/// HTTP server settings.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    // Built from HOST/PORT (defaults 0.0.0.0:3000).
    pub bind_address: SocketAddr,
}

/// PostgreSQL connection settings.
#[derive(Debug, Clone)]
pub struct DatabaseConfig {
    // DATABASE_URL; must start with postgres:// or postgresql://.
    pub url: String,
    // DB_MAX_CONNECTIONS (default 10).
    pub max_connections: u32,
    // DB_CONNECT_TIMEOUT in seconds (default 30).
    pub connect_timeout_seconds: u64,
}

/// Redis connection settings.
#[derive(Debug, Clone)]
pub struct RedisConfig {
    // REDIS_URL; must start with redis:// or rediss://.
    pub url: String,
    // REDIS_POOL_SIZE (default 5).
    pub pool_size: u32,
    // REDIS_CONNECT_TIMEOUT in seconds (default 10).
    pub connect_timeout_seconds: u64,
}

/// Discord OAuth2 application credentials.
#[derive(Debug, Clone)]
pub struct DiscordConfig {
    pub client_id: String,
    pub client_secret: String,
    // Validated to start with "http" by Config::validate.
    pub redirect_uri: String,
}
/// Errors produced while loading or validating the application [`Config`].
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// A required variable (e.g. DATABASE_URL, CLIENT_ID) was not set.
    #[error("Missing required environment variable: {0}")]
    MissingEnvVar(String),
    /// A variable was set but failed a semantic check in `validate()`.
    #[error("Invalid value for {var}: {value} - {reason}")]
    InvalidValue {
        var: String,
        value: String,
        reason: String,
    },
    /// A variable was set but could not be parsed into its target type
    /// (e.g. a non-numeric PORT); the inner parse error is preserved.
    #[error("Parse error for {var}: {source}")]
    ParseError {
        var: String,
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },
}
impl Config {
    /// Loads every configuration section from environment variables.
    ///
    /// # Errors
    /// Propagates the first [`ConfigError`] hit by any section loader,
    /// in the order: server, database, redis, discord.
    pub fn from_env() -> Result<Self, ConfigError> {
        let server = ServerConfig::from_env()?;
        let database = DatabaseConfig::from_env()?;
        let redis = RedisConfig::from_env()?;
        let discord = DiscordConfig::from_env()?;

        Ok(Config {
            server,
            database,
            redis,
            discord,
        })
    }

    /// Semantic validation beyond parsing: URL-scheme checks for the Discord
    /// redirect URI and the Postgres/Redis connection strings.
    ///
    /// Connection strings are reported as "***hidden***" so credentials
    /// embedded in them never appear in error output or logs.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if !self.discord.redirect_uri.starts_with("http") {
            return Err(ConfigError::InvalidValue {
                var: "REDIRECT_URI".to_string(),
                value: self.discord.redirect_uri.clone(),
                reason: "Must start with http:// or https://".to_string(),
            });
        }

        if !self.database.url.starts_with("postgres://")
            && !self.database.url.starts_with("postgresql://")
        {
            return Err(ConfigError::InvalidValue {
                var: "DATABASE_URL".to_string(),
                value: "***hidden***".to_string(),
                reason: "Must be a valid PostgreSQL connection string".to_string(),
            });
        }

        if !self.redis.url.starts_with("redis://") && !self.redis.url.starts_with("rediss://") {
            return Err(ConfigError::InvalidValue {
                var: "REDIS_URL".to_string(),
                value: "***hidden***".to_string(),
                reason: "Must be a valid Redis connection string".to_string(),
            });
        }

        Ok(())
    }
}
impl ServerConfig {
    /// Builds the bind address from HOST/PORT, defaulting to 0.0.0.0:3000.
    ///
    /// # Errors
    /// Returns [`ConfigError::ParseError`] when HOST is not an IP address or
    /// PORT is not a valid u16.
    pub fn from_env() -> Result<Self, ConfigError> {
        let host: IpAddr = get_env_or("HOST", "0.0.0.0")?
            .parse()
            .map_err(|e| ConfigError::ParseError {
                var: "HOST".to_string(),
                source: Box::new(e),
            })?;

        let port: u16 = get_env_or("PORT", "3000")?
            .parse()
            .map_err(|e| ConfigError::ParseError {
                var: "PORT".to_string(),
                source: Box::new(e),
            })?;

        Ok(ServerConfig {
            bind_address: SocketAddr::new(host, port),
        })
    }
}
impl DatabaseConfig {
    /// Reads DATABASE_URL (required) plus the optional pool-tuning knobs
    /// DB_MAX_CONNECTIONS (default 10) and DB_CONNECT_TIMEOUT (default 30s).
    ///
    /// # Errors
    /// [`ConfigError::MissingEnvVar`] when DATABASE_URL is absent;
    /// [`ConfigError::ParseError`] when a numeric knob fails to parse.
    pub fn from_env() -> Result<Self, ConfigError> {
        let url = get_env_required("DATABASE_URL")?;

        let max_connections: u32 = get_env_or("DB_MAX_CONNECTIONS", "10")?
            .parse()
            .map_err(|e| ConfigError::ParseError {
                var: "DB_MAX_CONNECTIONS".to_string(),
                source: Box::new(e),
            })?;

        let connect_timeout_seconds: u64 = get_env_or("DB_CONNECT_TIMEOUT", "30")?
            .parse()
            .map_err(|e| ConfigError::ParseError {
                var: "DB_CONNECT_TIMEOUT".to_string(),
                source: Box::new(e),
            })?;

        Ok(DatabaseConfig {
            url,
            max_connections,
            connect_timeout_seconds,
        })
    }
}
impl RedisConfig {
    /// Reads REDIS_URL (required) plus the optional pool-tuning knobs
    /// REDIS_POOL_SIZE (default 5) and REDIS_CONNECT_TIMEOUT (default 10s).
    ///
    /// # Errors
    /// [`ConfigError::MissingEnvVar`] when REDIS_URL is absent;
    /// [`ConfigError::ParseError`] when a numeric knob fails to parse.
    pub fn from_env() -> Result<Self, ConfigError> {
        let url = get_env_required("REDIS_URL")?;

        let pool_size: u32 = get_env_or("REDIS_POOL_SIZE", "5")?
            .parse()
            .map_err(|e| ConfigError::ParseError {
                var: "REDIS_POOL_SIZE".to_string(),
                source: Box::new(e),
            })?;

        let connect_timeout_seconds: u64 = get_env_or("REDIS_CONNECT_TIMEOUT", "10")?
            .parse()
            .map_err(|e| ConfigError::ParseError {
                var: "REDIS_CONNECT_TIMEOUT".to_string(),
                source: Box::new(e),
            })?;

        Ok(RedisConfig {
            url,
            pool_size,
            connect_timeout_seconds,
        })
    }
}
impl DiscordConfig {
    /// Reads the required CLIENT_ID, CLIENT_SECRET and REDIRECT_URI
    /// variables; there are no defaults for OAuth credentials.
    ///
    /// # Errors
    /// [`ConfigError::MissingEnvVar`] for the first unset variable,
    /// checked in the order listed above.
    pub fn from_env() -> Result<Self, ConfigError> {
        Ok(DiscordConfig {
            client_id: get_env_required("CLIENT_ID")?,
            client_secret: get_env_required("CLIENT_SECRET")?,
            redirect_uri: get_env_required("REDIRECT_URI")?,
        })
    }
}
/// Fetches a required environment variable, mapping absence (or non-UTF-8
/// content) to [`ConfigError::MissingEnvVar`].
fn get_env_required(key: &str) -> Result<String, ConfigError> {
    match env::var(key) {
        Ok(value) => Ok(value),
        Err(_) => Err(ConfigError::MissingEnvVar(key.to_string())),
    }
}

/// Fetches an environment variable, falling back to `default` when it is
/// unset; always succeeds (the Result keeps call sites uniform with
/// `get_env_required`).
fn get_env_or(key: &str, default: &str) -> Result<String, ConfigError> {
    match env::var(key) {
        Ok(value) => Ok(value),
        Err(_) => Ok(default.to_string()),
    }
}

View file

@ -1,11 +1,23 @@
pub mod postgres;
pub mod redis_helper;
use crate::config::Config;
pub use redis_helper::RedisPool;
pub type Db = sqlx::PgPool;
pub type Redis = redis::aio::MultiplexedConnection;
#[derive(Clone)]
pub struct AppState {
pub db: Db,
pub redis: Redis,
pub redis: RedisPool,
pub config: Config,
}
impl std::fmt::Debug for AppState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AppState")
.field("redis", &self.redis)
.field("config", &self.config)
.finish_non_exhaustive()
}
}

View file

@ -1,27 +1,122 @@
use crate::config::DatabaseConfig;
use sqlx::{postgres::PgPoolOptions, PgPool};
use std::env;
pub async fn connect() -> PgPool {
let db_url = env::var("DATABASE_URL").expect("DATABASE_URL is required");
use std::fs;
use std::path::Path;
use std::time::Duration;
use tracing::{error, info, warn};
pub async fn connect(config: &DatabaseConfig) -> Result<PgPool, sqlx::Error> {
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(&db_url)
.await
.expect("Failed to connect to Postgres");
.max_connections(config.max_connections)
.acquire_timeout(Duration::from_secs(config.connect_timeout_seconds))
.idle_timeout(Some(Duration::from_secs(600)))
.max_lifetime(Some(Duration::from_secs(1800)))
.connect(&config.url)
.await?;
create_migrations_table(&pool).await?;
run_migrations(&pool).await?;
Ok(pool)
}
async fn create_migrations_table(pool: &PgPool) -> Result<(), sqlx::Error> {
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS timezones (
user_id TEXT PRIMARY KEY,
username TEXT NOT NULL,
timezone TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at TIMESTAMPTZ DEFAULT NOW()
)
"#,
)
.execute(&pool)
.await
.expect("Failed to create timezones table");
.execute(pool)
.await?;
pool
Ok(())
}
/// Applies pending `.sql` files from ./migrations in lexicographic order.
///
/// Each migration runs inside its own transaction and is recorded in
/// `schema_migrations`, so a file is applied at most once.
async fn run_migrations(pool: &PgPool) -> Result<(), sqlx::Error> {
    let migrations_dir = Path::new("migrations");
    if !migrations_dir.exists() {
        // A missing directory is treated as "nothing to do", not an error.
        warn!("Migrations directory not found, skipping migrations");
        return Ok(());
    }

    let mut migration_files = Vec::new();
    match fs::read_dir(migrations_dir) {
        Ok(entries) => {
            for entry in entries {
                // Unreadable directory entries are silently skipped.
                if let Ok(entry) = entry {
                    let path = entry.path();
                    // Only *.sql files participate; anything else is ignored.
                    if path.extension().and_then(|s| s.to_str()) == Some("sql") {
                        if let Some(file_name) = path.file_name().and_then(|s| s.to_str()) {
                            migration_files.push(file_name.to_string());
                        }
                    }
                }
            }
        }
        Err(e) => {
            error!("Failed to read migrations directory: {}", e);
            return Err(sqlx::Error::Io(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Failed to read migrations directory: {}", e),
            )));
        }
    }

    // Lexicographic order doubles as application order, so file names must
    // sort correctly (e.g. zero-padded numeric prefixes).
    migration_files.sort();

    for migration_file in migration_files {
        let applied = sqlx::query_scalar::<_, bool>(
            "SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = $1)",
        )
        .bind(&migration_file)
        .fetch_one(pool)
        .await?;

        if applied {
            info!("Migration {} already applied, skipping", migration_file);
            continue;
        }

        let migration_path = migrations_dir.join(&migration_file);
        let migration_sql = match fs::read_to_string(&migration_path) {
            Ok(content) => content,
            Err(e) => {
                error!("Failed to read migration file {}: {}", migration_file, e);
                return Err(sqlx::Error::Io(e));
            }
        };

        info!("Running migration: {}", migration_file);

        let mut tx = pool.begin().await?;

        // NOTE(review): statements are split on bare ';', which breaks for
        // semicolons inside string literals or PL/pgSQL function bodies —
        // keep migrations to plain statements, or revisit this splitter.
        let statements: Vec<&str> = migration_sql
            .split(';')
            .map(|s| s.trim())
            .filter(|s| !s.is_empty())
            .collect();

        for statement in statements {
            if let Err(e) = sqlx::query(statement).execute(&mut *tx).await {
                error!("Failed to execute migration {}: {}", migration_file, e);
                return Err(e);
            }
        }

        // Record the version inside the same transaction so a crash cannot
        // leave a migration applied but unrecorded.
        sqlx::query("INSERT INTO schema_migrations (version) VALUES ($1)")
            .bind(&migration_file)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        info!("Successfully applied migration: {}", migration_file);
    }

    Ok(())
}

View file

@ -1,12 +1,137 @@
use redis::aio::MultiplexedConnection;
use redis::Client;
use std::env;
use crate::config::RedisConfig;
use redis::{aio::MultiplexedConnection, Client, RedisError};
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
pub async fn connect() -> MultiplexedConnection {
let url = env::var("REDIS_URL").expect("REDIS_URL is required");
let client = Client::open(url).expect("Failed to create Redis client");
client
.get_multiplexed_tokio_connection()
.await
.expect("Failed to connect to Redis")
/// Alias for the underlying multiplexed Redis connection type.
pub type RedisConnection = MultiplexedConnection;

/// A simple pool of multiplexed Redis connections.
///
/// Cloning the pool is cheap: the idle-connection queue is shared behind an
/// `Arc<Mutex<..>>`, and the `Client` is used to open new connections on
/// demand.
#[derive(Clone)]
pub struct RedisPool {
    // Idle connections waiting to be checked out.
    connections: Arc<Mutex<VecDeque<RedisConnection>>>,
    // Used to open replacement/overflow connections.
    client: Client,
    config: RedisConfig,
}

/// Manual Debug impl: the connection queue has no useful Debug form, so only
/// the configuration is shown.
impl std::fmt::Debug for RedisPool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisPool")
            .field("config", &self.config)
            .field("pool_size", &self.config.pool_size)
            .finish()
    }
}
impl RedisPool {
    /// Creates a pool and eagerly opens `config.pool_size` connections,
    /// failing fast if Redis is unreachable at startup.
    pub async fn new(config: RedisConfig) -> Result<Self, RedisError> {
        let client = Client::open(config.url.clone())?;
        let connections = Arc::new(Mutex::new(VecDeque::new()));

        let pool = RedisPool {
            connections,
            client,
            config,
        };

        pool.initialize_pool().await?;
        Ok(pool)
    }

    /// Fills the queue with `pool_size` fresh connections.
    async fn initialize_pool(&self) -> Result<(), RedisError> {
        let mut connections = self.connections.lock().await;
        for _ in 0..self.config.pool_size {
            let conn = self.create_connection().await?;
            connections.push_back(conn);
        }
        Ok(())
    }

    /// Opens one connection, bounded by `connect_timeout_seconds`; a timeout
    /// surfaces as an IoError-kind RedisError.
    async fn create_connection(&self) -> Result<RedisConnection, RedisError> {
        tokio::time::timeout(
            Duration::from_secs(self.config.connect_timeout_seconds),
            self.client.get_multiplexed_tokio_connection(),
        )
        .await
        .map_err(|_| RedisError::from((redis::ErrorKind::IoError, "Connection timeout")))?
    }

    /// Checks out a connection; when the queue is empty an overflow
    /// connection is opened, so in-flight connections can temporarily exceed
    /// `pool_size`. The returned guard gives the connection back on drop.
    pub async fn get_connection(&self) -> Result<PooledConnection, RedisError> {
        let mut connections = self.connections.lock().await;

        let conn = if let Some(conn) = connections.pop_front() {
            conn
        } else {
            // Release the queue lock before the (possibly slow) connect.
            drop(connections);
            self.create_connection().await?
        };

        Ok(PooledConnection {
            connection: Some(conn),
            pool: self.clone(),
        })
    }

    /// Puts a connection back; overflow connections beyond `pool_size` are
    /// simply dropped here, shrinking the pool back to its configured size.
    async fn return_connection(&self, conn: RedisConnection) {
        let mut connections = self.connections.lock().await;
        if connections.len() < self.config.pool_size as usize {
            connections.push_back(conn);
        }
    }
}
/// Guard around a checked-out Redis connection; returns it to the pool when
/// dropped.
pub struct PooledConnection {
    // Some while checked out; taken (None) when Drop hands it back.
    connection: Option<RedisConnection>,
    pool: RedisPool,
}

impl PooledConnection {
    /// Borrows the inner connection mutably for issuing commands.
    ///
    /// # Panics
    /// Panics if the connection has already been taken — only Drop does
    /// that, so this cannot fire through normal use.
    pub fn as_mut(&mut self) -> &mut RedisConnection {
        self.connection
            .as_mut()
            .expect("Connection already returned to pool")
    }
}
/// Lets a `PooledConnection` be used directly with the `redis` crate's
/// high-level command API by delegating to the pooled inner connection.
impl redis::aio::ConnectionLike for PooledConnection {
    fn req_packed_command<'a>(
        &'a mut self,
        cmd: &'a redis::Cmd,
    ) -> redis::RedisFuture<'a, redis::Value> {
        self.as_mut().req_packed_command(cmd)
    }

    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a redis::Pipeline,
        offset: usize,
        count: usize,
    ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
        self.as_mut().req_packed_commands(cmd, offset, count)
    }

    fn get_db(&self) -> i64 {
        self.connection
            .as_ref()
            .expect("Connection already returned to pool")
            .get_db()
    }
}
impl Drop for PooledConnection {
    /// Hands the connection back to the pool asynchronously (Drop cannot
    /// await, so the return is spawned as a task).
    ///
    /// NOTE(review): `tokio::spawn` requires an active Tokio runtime; a
    /// guard dropped outside one (e.g. in a non-async test) would panic —
    /// confirm all drop sites run inside the runtime.
    fn drop(&mut self) {
        if let Some(conn) = self.connection.take() {
            let pool = self.pool.clone();
            tokio::spawn(async move {
                pool.return_connection(conn).await;
            });
        }
    }
}
/// Builds a Redis connection pool from config; convenience wrapper mirroring
/// the shape of `postgres::connect`.
pub async fn connect(config: &RedisConfig) -> Result<RedisPool, RedisError> {
    RedisPool::new(config.clone()).await
}

View file

@ -1,48 +1,104 @@
use axum::{serve, Router};
use dotenvy::dotenv;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tracing::{error, info};
use tracing_subscriber;
use tracing::{error, info, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
mod config;
mod db;
mod middleware;
mod routes;
mod types;
use config::Config;
use db::{postgres, redis_helper, AppState};
use middleware::cors::DynamicCors;
#[tokio::main]
async fn main() {
dotenv().ok();
tracing_subscriber::fmt::init();
let db = postgres::connect().await;
let redis = redis_helper::connect().await;
let state = AppState { db, redis };
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let config = match Config::from_env() {
Ok(config) => {
if let Err(e) = config.validate() {
error!("Configuration validation failed: {}", e);
std::process::exit(1);
}
config
}
Err(e) => {
error!("Failed to load configuration: {}", e);
std::process::exit(1);
}
};
info!("Starting timezone-db server");
info!("Server will bind to: {}", config.server.bind_address);
let db = match postgres::connect(&config.database).await {
Ok(pool) => {
info!("Successfully connected to PostgreSQL");
pool
}
Err(e) => {
error!("Failed to connect to PostgreSQL: {}", e);
std::process::exit(1);
}
};
let redis = match redis_helper::connect(&config.redis).await {
Ok(pool) => {
info!("Successfully connected to Redis");
pool
}
Err(e) => {
error!("Failed to connect to Redis: {}", e);
std::process::exit(1);
}
};
let state = AppState {
db,
redis,
config: config.clone(),
};
let app = Router::new()
.merge(routes::all())
.with_state(state.clone())
.with_state(state)
.layer(DynamicCors);
let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into());
let port: u16 = std::env::var("PORT")
.unwrap_or_else(|_| "3000".to_string())
.parse()
.expect("PORT must be a number");
let listener = match TcpListener::bind(config.server.bind_address).await {
Ok(listener) => listener,
Err(e) => {
error!("Failed to bind to {}: {}", config.server.bind_address, e);
std::process::exit(1);
}
};
let addr = format!("{}:{}", host, port)
.parse::<SocketAddr>()
.expect("Invalid HOST or PORT");
info!("Server listening on http://{}", config.server.bind_address);
let listener = TcpListener::bind(addr)
let shutdown_signal = async {
tokio::signal::ctrl_c()
.await
.expect("Failed to install CTRL+C signal handler");
warn!("Shutdown signal received");
};
if let Err(err) = serve(listener, app)
.with_graceful_shutdown(shutdown_signal)
.await
.expect("Failed to bind address");
info!("Listening on http://{}", addr);
if let Err(err) = serve(listener, app).await {
{
error!("Server error: {}", err);
std::process::exit(1);
}
info!("Server has shut down gracefully");
}

View file

@ -1,5 +1,4 @@
use crate::db::AppState;
use crate::types::JsonMessage;
use axum::{
extract::{Query, State},
@ -11,7 +10,8 @@ use headers::{Cookie, HeaderMapExt};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use sqlx::Row;
use std::{collections::HashMap, env};
use std::collections::HashMap;
use tracing::{error, info, instrument};
use uuid::Uuid;
#[derive(Deserialize)]
@ -20,7 +20,7 @@ pub struct CallbackQuery {
state: Option<String>,
}
#[derive(Deserialize, Serialize)]
#[derive(Deserialize, Serialize, Clone)]
pub struct DiscordUser {
pub id: String,
pub username: String,
@ -34,6 +34,7 @@ pub struct AuthResponse {
session: String,
}
#[instrument(skip(state), fields(user_id))]
pub async fn get_user_from_session(
headers: &HeaderMap,
state: &AppState,
@ -56,9 +57,20 @@ pub async fn get_user_from_session(
));
};
let mut redis = state.redis.clone();
let mut redis_conn = state.redis.get_connection().await.map_err(|e| {
error!("Failed to get Redis connection: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
message: "Database connection error".into(),
}),
)
})?;
let key = format!("session:{}", session_id);
let Ok(json) = redis.get::<_, String>(&key).await else {
let json: redis::RedisResult<String> = redis_conn.as_mut().get(&key).await;
let Ok(json) = json else {
return Err((
StatusCode::UNAUTHORIZED,
Json(JsonMessage {
@ -76,32 +88,39 @@ pub async fn get_user_from_session(
));
};
tracing::Span::current().record("user_id", &user.id);
Ok(user)
}
pub async fn start_oauth(Query(params): Query<HashMap<String, String>>) -> impl IntoResponse {
let client_id = env::var("CLIENT_ID").unwrap_or_default();
let redirect_uri = env::var("REDIRECT_URI").unwrap_or_default();
#[instrument(skip(state))]
pub async fn start_oauth(
State(state): State<AppState>,
Query(params): Query<HashMap<String, String>>,
) -> impl IntoResponse {
let client_id = &state.config.discord.client_id;
let redirect_uri = &state.config.discord.redirect_uri;
let mut url = format!(
"https://discord.com/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify",
client_id, redirect_uri
);
"https://discord.com/oauth2/authorize?client_id={}&redirect_uri={}&response_type=code&scope=identify",
client_id, redirect_uri
);
if let Some(redirect) = params.get("redirect") {
url.push_str(&format!("&state={}", urlencoding::encode(redirect)));
}
info!("Starting OAuth flow");
(StatusCode::FOUND, [(axum::http::header::LOCATION, url)]).into_response()
}
#[instrument(skip(state, query), fields(user_id))]
pub async fn handle_callback(
State(state): State<AppState>,
Query(query): Query<CallbackQuery>,
) -> impl IntoResponse {
let client_id = env::var("CLIENT_ID").unwrap();
let client_secret = env::var("CLIENT_SECRET").unwrap();
let redirect_uri = env::var("REDIRECT_URI").unwrap();
let client_id = &state.config.discord.client_id;
let client_secret = &state.config.discord.client_secret;
let redirect_uri = &state.config.discord.redirect_uri;
let form = [
("client_id", client_id.as_str()),
@ -118,6 +137,7 @@ pub async fn handle_callback(
.await;
let Ok(res) = token_res else {
error!("Failed to exchange OAuth code for token");
return (
StatusCode::BAD_REQUEST,
Json(JsonMessage {
@ -128,6 +148,7 @@ pub async fn handle_callback(
};
let Ok(token_json) = res.json::<serde_json::Value>().await else {
error!("Invalid token response from Discord");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
@ -138,6 +159,7 @@ pub async fn handle_callback(
};
let Some(access_token) = token_json["access_token"].as_str() else {
error!("Access token not found in Discord response");
return (
StatusCode::UNAUTHORIZED,
Json(JsonMessage {
@ -154,6 +176,7 @@ pub async fn handle_callback(
.await;
let Ok(user_res) = user_res else {
error!("Failed to fetch user info from Discord");
return (
StatusCode::BAD_REQUEST,
Json(JsonMessage {
@ -164,6 +187,7 @@ pub async fn handle_callback(
};
let Ok(user) = user_res.json::<DiscordUser>().await else {
error!("Failed to parse user info from Discord");
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
@ -173,15 +197,42 @@ pub async fn handle_callback(
.into_response();
};
tracing::Span::current().record("user_id", &user.id);
let session_id = Uuid::now_v7().to_string();
let mut redis = state.redis.clone();
let _ = redis
let mut redis_conn = match state.redis.get_connection().await {
Ok(conn) => conn,
Err(e) => {
error!("Failed to get Redis connection: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
message: "Database connection error".into(),
}),
)
.into_response();
}
};
if let Err(e) = redis_conn
.as_mut()
.set_ex::<_, _, ()>(
format!("session:{}", session_id),
serde_json::to_string(&user).unwrap(),
3600,
)
.await;
.await
{
error!("Failed to store session in Redis: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
message: "Failed to create session".into(),
}),
)
.into_response();
}
let redirect_target = match &query.state {
Some(s) => urlencoding::decode(s)
@ -205,9 +256,11 @@ pub async fn handle_callback(
redirect_target.parse().unwrap(),
);
info!(user_id = %user.id, username = %user.username, "User logged in successfully");
(StatusCode::FOUND, headers).into_response()
}
#[instrument(skip(state))]
pub async fn me(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
match get_user_from_session(&headers, &state).await {
Ok(user) => {
@ -236,13 +289,16 @@ pub async fn me(State(state): State<AppState>, headers: HeaderMap) -> impl IntoR
})),
)
.into_response(),
Err(_) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
message: "Failed to fetch timezone".into(),
}),
)
.into_response(),
Err(e) => {
error!("Database error while fetching timezone: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
message: "Failed to fetch timezone".into(),
}),
)
.into_response()
}
}
}
Err(err) => err.into_response(),

31
src/routes/health.rs Normal file
View file

@ -0,0 +1,31 @@
use axum::{extract::State, http::StatusCode, response::IntoResponse, Json};

use crate::db::AppState;

/// Liveness/readiness probe backing the container healthcheck.
///
/// Pings PostgreSQL with a trivial query and verifies a Redis connection can
/// be checked out of the pool; responds 200 with a JSON status body when both
/// succeed, 503 Service Unavailable otherwise.
pub async fn health_check(State(state): State<AppState>) -> impl IntoResponse {
    let db_healthy = sqlx::query("SELECT 1").execute(&state.db).await.is_ok();
    let redis_healthy = state.redis.get_connection().await.is_ok();

    // Use the boolean directly instead of round-tripping through the status
    // string; StatusCode comes from axum's http re-export rather than the
    // reqwest client crate (same underlying type, clearer dependency).
    let healthy = db_healthy && redis_healthy;
    let status = if healthy { "healthy" } else { "unhealthy" };
    let status_code = if healthy {
        StatusCode::OK
    } else {
        StatusCode::SERVICE_UNAVAILABLE
    };

    (
        status_code,
        Json(serde_json::json!({
            "status": status,
            "database": db_healthy,
            "redis": redis_healthy,
            "timestamp": chrono::Utc::now()
        })),
    )
}

View file

@ -9,6 +9,7 @@ use std::fs;
use tower_http::services::ServeDir;
mod auth;
mod health;
mod timezone;
async fn preflight_handler() -> Response {
@ -53,6 +54,7 @@ pub fn all() -> Router<AppState> {
.route("/auth/discord", get(auth::start_oauth))
.route("/auth/discord/callback", get(auth::handle_callback))
.route("/me", get(auth::me))
.route("/health", get(health::health_check))
.nest_service("/public", ServeDir::new("public"))
.fallback(get(index_page))
}

View file

@ -13,6 +13,7 @@ use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use sqlx::Row;
use std::collections::HashMap;
use tracing::error;
#[derive(Serialize)]
pub struct TimezoneResponse {
@ -132,9 +133,22 @@ pub async fn delete_timezone(
.into_response();
};
let mut redis = state.redis.clone();
let mut redis_conn = match state.redis.get_connection().await {
Ok(conn) => conn,
Err(e) => {
error!("Failed to get Redis connection: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
message: "Database connection error".into(),
}),
)
.into_response();
}
};
let key = format!("session:{}", session_id);
let json: redis::RedisResult<String> = redis.get(&key).await;
let json: redis::RedisResult<String> = redis_conn.get(&key).await;
let Ok(json) = json else {
return (
@ -204,9 +218,22 @@ pub async fn set_timezone(
.into_response();
};
let mut redis = state.redis.clone();
let mut redis_conn = match state.redis.get_connection().await {
Ok(conn) => conn,
Err(e) => {
error!("Failed to get Redis connection: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(JsonMessage {
message: "Database connection error".into(),
}),
)
.into_response();
}
};
let key = format!("session:{}", session_id);
let json: redis::RedisResult<String> = redis.get(&key).await;
let json: redis::RedisResult<String> = redis_conn.get(&key).await;
let Ok(json) = json else {
return (
@ -251,11 +278,11 @@ pub async fn set_timezone(
let result = sqlx::query(
r#"
INSERT INTO timezones (user_id, username, timezone)
VALUES ($1, $2, $3)
ON CONFLICT (user_id) DO UPDATE
SET username = EXCLUDED.username, timezone = EXCLUDED.timezone
"#,
INSERT INTO timezones (user_id, username, timezone)
VALUES ($1, $2, $3)
ON CONFLICT (user_id) DO UPDATE
SET username = EXCLUDED.username, timezone = EXCLUDED.timezone
"#,
)
.bind(&user.id)
.bind(&user.username)