this should fix cors
This commit is contained in:
parent
2da598738b
commit
faa21a5d7e
5 changed files with 68 additions and 4 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -2171,6 +2171,7 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
|
|
|
@ -18,3 +18,4 @@ reqwest = { version = "0.12", features = ["json", "gzip"] }
|
||||||
tower-http = { version = "0.6.4", features = ["cors"] }
|
tower-http = { version = "0.6.4", features = ["cors"] }
|
||||||
headers = "0.4.0"
|
headers = "0.4.0"
|
||||||
chrono-tz = "0.10.3"
|
chrono-tz = "0.10.3"
|
||||||
|
tower = "0.5.2"
|
||||||
|
|
|
@ -1,16 +1,17 @@
|
||||||
use axum::{Router, serve};
|
use axum::{serve, Router};
|
||||||
use dotenvy::dotenv;
|
use dotenvy::dotenv;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tower_http::cors::CorsLayer;
|
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use tracing_subscriber;
|
use tracing_subscriber;
|
||||||
|
|
||||||
mod db;
|
mod db;
|
||||||
|
mod middleware;
|
||||||
mod routes;
|
mod routes;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
use db::{AppState, postgres, redis_helper};
|
use db::{postgres, redis_helper, AppState};
|
||||||
|
use middleware::cors::DynamicCors;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
|
@ -24,7 +25,7 @@ async fn main() {
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(routes::all())
|
.merge(routes::all())
|
||||||
.with_state(state.clone())
|
.with_state(state.clone())
|
||||||
.layer(CorsLayer::permissive());
|
.layer(DynamicCors);
|
||||||
|
|
||||||
let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into());
|
let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into());
|
||||||
let port: u16 = std::env::var("PORT")
|
let port: u16 = std::env::var("PORT")
|
||||||
|
|
60
src/middleware/cors.rs
Normal file
60
src/middleware/cors.rs
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
use axum::http::{HeaderValue, Request, Response};
|
||||||
|
use std::{
|
||||||
|
future::Future,
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
use tower::{Layer, Service};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct DynamicCors;
|
||||||
|
|
||||||
|
impl<S> Layer<S> for DynamicCors {
|
||||||
|
type Service = CorsMiddleware<S>;
|
||||||
|
|
||||||
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
|
CorsMiddleware { inner }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct CorsMiddleware<S> {
|
||||||
|
inner: S,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for CorsMiddleware<S>
|
||||||
|
where
|
||||||
|
S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
|
||||||
|
S::Future: Send + 'static,
|
||||||
|
ReqBody: Send + 'static,
|
||||||
|
ResBody: 'static,
|
||||||
|
{
|
||||||
|
type Response = Response<ResBody>;
|
||||||
|
type Error = S::Error;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.inner.poll_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
|
||||||
|
let origin = req.headers().get("origin").cloned();
|
||||||
|
let mut inner = self.inner.clone();
|
||||||
|
|
||||||
|
Box::pin(async move {
|
||||||
|
let mut res = inner.call(req).await?;
|
||||||
|
|
||||||
|
if let Some(origin) = origin {
|
||||||
|
let headers = res.headers_mut();
|
||||||
|
headers.insert("access-control-allow-origin", origin);
|
||||||
|
headers.insert(
|
||||||
|
"access-control-allow-credentials",
|
||||||
|
HeaderValue::from_static("true"),
|
||||||
|
);
|
||||||
|
headers.insert("vary", HeaderValue::from_static("Origin"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
1
src/middleware/mod.rs
Normal file
1
src/middleware/mod.rs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
pub mod cors;
|
Loading…
Add table
Add a link
Reference in a new issue