From faa21a5d7e9774a1f56ab9abca3662a03cdd16f1 Mon Sep 17 00:00:00 2001 From: creations Date: Sat, 31 May 2025 19:05:33 -0400 Subject: [PATCH] this should fix cors --- Cargo.lock | 1 + Cargo.toml | 1 + src/main.rs | 9 ++++--- src/middleware/cors.rs | 60 ++++++++++++++++++++++++++++++++++++++++++ src/middleware/mod.rs | 1 + 5 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 src/middleware/cors.rs create mode 100644 src/middleware/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 372318c..8a192ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2171,6 +2171,7 @@ dependencies = [ "serde_json", "sqlx", "tokio", + "tower", "tower-http", "tracing", "tracing-subscriber", diff --git a/Cargo.toml b/Cargo.toml index 9c79317..fb99e20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,4 @@ reqwest = { version = "0.12", features = ["json", "gzip"] } tower-http = { version = "0.6.4", features = ["cors"] } headers = "0.4.0" chrono-tz = "0.10.3" +tower = "0.5.2" diff --git a/src/main.rs b/src/main.rs index a1d9e0b..e22600d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,17 @@ -use axum::{Router, serve}; +use axum::{serve, Router}; use dotenvy::dotenv; use std::net::SocketAddr; use tokio::net::TcpListener; -use tower_http::cors::CorsLayer; use tracing::{error, info}; use tracing_subscriber; mod db; +mod middleware; mod routes; mod types; -use db::{AppState, postgres, redis_helper}; +use db::{postgres, redis_helper, AppState}; +use middleware::cors::DynamicCors; #[tokio::main] async fn main() { @@ -24,7 +25,7 @@ async fn main() { let app = Router::new() .merge(routes::all()) .with_state(state.clone()) - .layer(CorsLayer::permissive()); + .layer(DynamicCors); let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into()); let port: u16 = std::env::var("PORT") diff --git a/src/middleware/cors.rs b/src/middleware/cors.rs new file mode 100644 index 0000000..1bdc66e --- /dev/null +++ b/src/middleware/cors.rs @@ -0,0 +1,60 @@ +use axum::http::{HeaderValue, Request, Response}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower::{Layer, Service}; + +#[derive(Clone)] +pub struct DynamicCors; + +impl Layer for DynamicCors { + type Service = CorsMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + CorsMiddleware { inner } + } +} + +#[derive(Clone)] +pub struct CorsMiddleware { + inner: S, +} + +impl Service> for CorsMiddleware +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + ReqBody: Send + 'static, + ResBody: 'static, +{ + type Response = Response; + type Error = S::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let origin = req.headers().get("origin").cloned(); + let mut inner = self.inner.clone(); + + Box::pin(async move { + let mut res = inner.call(req).await?; + + if let Some(origin) = origin { + let headers = res.headers_mut(); + headers.insert("access-control-allow-origin", origin); + headers.insert( + "access-control-allow-credentials", + HeaderValue::from_static("true"), + ); + headers.insert("vary", HeaderValue::from_static("Origin")); + } + + Ok(res) + }) + } +} diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs new file mode 100644 index 0000000..3bd498a --- /dev/null +++ b/src/middleware/mod.rs @@ -0,0 +1 @@ +pub mod cors;