From 8ed77d7ab484121e9d70158e14c9fd6c243f1c70 Mon Sep 17 00:00:00 2001 From: FxQnLr Date: Mon, 12 Feb 2024 14:58:08 +0100 Subject: Close #9. Config impl with struct and files --- src/auth.rs | 13 +++---------- src/config.rs | 32 +++++++++++++++++++++++--------- src/db.rs | 16 ++-------------- src/main.rs | 14 ++++++++------ src/routes/device.rs | 6 +++--- src/routes/start.rs | 8 +++----- src/services/ping.rs | 6 +++--- 7 files changed, 45 insertions(+), 50 deletions(-) (limited to 'src') diff --git a/src/auth.rs b/src/auth.rs index 0321ade..feca652 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -2,15 +2,13 @@ use axum::http::{StatusCode, HeaderValue}; use axum::http::header::ToStrError; use tracing::{debug, error, trace}; use crate::auth::Error::{MissingSecret, WrongSecret}; -use crate::config::SETTINGS; +use crate::config::Config; -pub fn auth(secret: Option<&HeaderValue>) -> Result { +pub fn auth(config: &Config, secret: Option<&HeaderValue>) -> Result { debug!("auth request with secret {:?}", secret); if let Some(value) = secret { trace!("value exists"); - let key = SETTINGS - .get_string("apikey") - .map_err(Error::Config)?; + let key = &config.apikey; if value.to_str().map_err(Error::HeaderToStr)? == key.as_str() { debug!("successful auth"); Ok(true) @@ -28,7 +26,6 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result { pub enum Error { WrongSecret, MissingSecret, - Config(config::ConfigError), HeaderToStr(ToStrError) } @@ -37,10 +34,6 @@ impl Error { match self { Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), - Self::Config(err) => { - error!("server error: {}", err.to_string()); - (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") - }, Self::HeaderToStr(err) => { error!("server error: {}", err.to_string()); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") diff --git a/src/config.rs b/src/config.rs index 4c79810..e88ddab 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,11 +1,25 @@ -use config::Config; -use once_cell::sync::Lazy; +use config::File; +use serde::Deserialize; -pub static SETTINGS: Lazy = Lazy::new(setup); +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub database_url: String, + pub apikey: String, + pub serveraddr: String, + pub pingtimeout: i64, +} + +impl Config { + pub fn load() -> Result { + let config = config::Config::builder() + .set_default("serveraddr", "0.0.0.0:7229")? + .set_default("pingtimeout", 10)? + .add_source(File::with_name("config.toml").required(false)) + .add_source(File::with_name("config.dev.toml").required(false)) + .add_source(config::Environment::with_prefix("WEBOL").separator("_")) + .build()?; + + config.try_deserialize() + } +} -fn setup() -> Config { - Config::builder() - .add_source(config::Environment::with_prefix("WEBOL").separator("_")) - .build() - .unwrap() -} \ No newline at end of file diff --git a/src/db.rs b/src/db.rs index 8a6b16e..489a000 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,13 +1,7 @@ -#[cfg(debug_assertions)] -use std::env; - use serde::Serialize; use sqlx::{PgPool, postgres::PgPoolOptions}; use tracing::{debug, info}; -#[cfg(not(debug_assertions))] -use crate::config::SETTINGS; - #[derive(Serialize, Debug)] pub struct Device { pub id: String, @@ -17,18 +11,12 @@ pub struct Device { pub times: Option> } -pub async fn init_db_pool() -> PgPool { - #[cfg(not(debug_assertions))] - let db_url = SETTINGS.get_string("database.url").unwrap(); - - #[cfg(debug_assertions)] - let db_url = env::var("DATABASE_URL").unwrap(); - +pub async fn init_db_pool(db_url: &str) -> PgPool { debug!("attempt to connect dbPool to '{}'", db_url); let pool = PgPoolOptions::new() .max_connections(5) - .connect(&db_url) + .connect(db_url) .await .unwrap(); diff --git a/src/main.rs b/src/main.rs index 9d30548..4ef129b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,4 @@ use std::env; -use std::net::SocketAddr; use std::sync::Arc; use axum::{Router, routing::post}; use axum::routing::{get, put}; @@ -9,7 +8,7 @@ use time::util::local_offset; use tokio::sync::broadcast::{channel, Sender}; use tracing::{info, level_filters::LevelFilter}; use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; -use crate::config::SETTINGS; +use crate::config::Config; use crate::db::init_db_pool; use crate::routes::device; use crate::routes::start::start; @@ -47,16 +46,18 @@ async fn main() -> color_eyre::eyre::Result<()> { let version = env!("CARGO_PKG_VERSION"); + let config = Config::load()?; + info!("start webol v{}", version); - let db = init_db_pool().await; + let db = init_db_pool(&config.database_url).await; sqlx::migrate!().run(&db).await.unwrap(); let (tx, _) = channel(32); let ping_map: StatusMap = DashMap::new(); - let shared_state = Arc::new(AppState { db, ping_send: tx, ping_map }); + let shared_state = Arc::new(AppState { db, config: config.clone(), ping_send: tx, ping_map }); let app = Router::new() .route("/start", post(start)) @@ -66,9 +67,9 @@ async fn main() -> color_eyre::eyre::Result<()> { .route("/status", get(status)) .with_state(shared_state); - let addr = SETTINGS.get_string("serveraddr").unwrap_or("0.0.0.0:7229".to_string()); + let addr = config.serveraddr; info!("start server on {}", addr); - let listener = tokio::net::TcpListener::bind(addr.parse::()?) + let listener = tokio::net::TcpListener::bind(addr) .await?; axum::serve(listener, app).await?; @@ -77,6 +78,7 @@ async fn main() -> color_eyre::eyre::Result<()> { pub struct AppState { db: PgPool, + config: Config, ping_send: Sender, ping_map: StatusMap, } diff --git a/src/routes/device.rs b/src/routes/device.rs index b80cb85..c85df1b 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs @@ -12,7 +12,7 @@ use crate::error::Error; pub async fn get(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, Error> { info!("add device {}", payload.id); let secret = headers.get("authorization"); - if auth(secret).map_err(Error::Auth)? { + if auth(&state.config, secret).map_err(Error::Auth)? { let device = sqlx::query_as!( Device, r#" @@ -39,7 +39,7 @@ pub struct GetDevicePayload { pub async fn put(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, Error> { info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); let secret = headers.get("authorization"); - if auth(secret).map_err(Error::Auth)? { + if auth(&state.config, secret).map_err(Error::Auth)? { sqlx::query!( r#" INSERT INTO devices (id, mac, broadcast_addr, ip) @@ -73,7 +73,7 @@ pub struct PutDeviceResponse { pub async fn post(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, Error> { info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); let secret = headers.get("authorization"); - if auth(secret).map_err(Error::Auth)? { + if auth(&state.config, secret).map_err(Error::Auth)? { let device = sqlx::query_as!( Device, r#" diff --git a/src/routes/start.rs b/src/routes/start.rs index 4264588..ce95bf3 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -1,5 +1,4 @@ use crate::auth::auth; -use crate::config::SETTINGS; use crate::db::Device; use crate::error::Error; use crate::services::ping::Value as PingValue; @@ -21,7 +20,7 @@ pub async fn start( ) -> Result, Error> { info!("POST request"); let secret = headers.get("authorization"); - let authorized = auth(secret).map_err(Error::Auth)?; + let authorized = auth(&state.config, secret).map_err(Error::Auth)?; if authorized { let device = sqlx::query_as!( Device, @@ -38,9 +37,7 @@ pub async fn start( info!("starting {}", device.id); - let bind_addr = SETTINGS - .get_string("bindaddr") - .unwrap_or("0.0.0.0:1111".to_string()); + let bind_addr = "0.0.0.0:0"; let _ = send_packet( &bind_addr.parse().map_err(Error::IpParse)?, @@ -75,6 +72,7 @@ pub async fn start( crate::services::ping::spawn( state.ping_send.clone(), + &state.config, device, uuid_gen.clone(), &state.ping_map, diff --git a/src/services/ping.rs b/src/services/ping.rs index 7d71218..9b164c8 100644 --- a/src/services/ping.rs +++ b/src/services/ping.rs @@ -10,7 +10,7 @@ use time::{Duration, Instant}; use tokio::sync::broadcast::Sender; use tracing::{debug, error, trace}; use crate::AppState; -use crate::config::SETTINGS; +use crate::config::Config; use crate::db::Device; pub type StatusMap = DashMap; @@ -21,7 +21,7 @@ pub struct Value { pub online: bool } -pub async fn spawn(tx: Sender, device: Device, uuid: String, ping_map: &StatusMap, db: &PgPool) { +pub async fn spawn(tx: Sender, config: &Config, device: Device, uuid: String, ping_map: &StatusMap, db: &PgPool) { let timer = Instant::now(); let payload = [0; 8]; @@ -40,7 +40,7 @@ pub async fn spawn(tx: Sender, device: Device, uuid: String, error!("{}", ping.to_string()); msg = Some(BroadcastCommands::Error(uuid.clone())); } - if timer.elapsed() >= Duration::minutes(SETTINGS.get_int("pingtimeout").unwrap_or(10)) { + if timer.elapsed() >= Duration::minutes(config.pingtimeout) { msg = Some(BroadcastCommands::Timeout(uuid.clone())); } } else { -- cgit v1.2.3