From 2f9f18b80a9e2134f674f345e48a5f21de5efadd Mon Sep 17 00:00:00 2001 From: FxQnLr Date: Sun, 18 Feb 2024 21:16:46 +0100 Subject: Refactor stuff. Use Postgres Types --- src/db.rs | 6 +- src/error.rs | 22 ++++++++ src/main.rs | 54 ++++++++++-------- src/routes.rs | 3 + src/routes/device.rs | 19 +++++-- src/routes/mod.rs | 3 - src/routes/start.rs | 78 +++++++++++++------------- src/routes/status.rs | 79 ++++++++++++++++++++++++-- src/services.rs | 1 + src/services/mod.rs | 1 - src/services/ping.rs | 154 +++++++++++++++++++++++---------------------------- 11 files changed, 255 insertions(+), 165 deletions(-) create mode 100644 src/routes.rs delete mode 100644 src/routes/mod.rs create mode 100644 src/services.rs delete mode 100644 src/services/mod.rs (limited to 'src') diff --git a/src/db.rs b/src/db.rs index 489a000..47e907d 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,13 +1,13 @@ use serde::Serialize; -use sqlx::{PgPool, postgres::PgPoolOptions}; +use sqlx::{PgPool, postgres::PgPoolOptions, types::{ipnetwork::IpNetwork, mac_address::MacAddress}}; use tracing::{debug, info}; #[derive(Serialize, Debug)] pub struct Device { pub id: String, - pub mac: String, + pub mac: MacAddress, pub broadcast_addr: String, - pub ip: String, + pub ip: IpNetwork, pub times: Option> } diff --git a/src/error.rs b/src/error.rs index 63b214e..66a61f4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,6 +2,8 @@ use axum::http::header::ToStrError; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; +use ::ipnetwork::IpNetworkError; +use mac_address::MacParseError; use serde_json::json; use std::io; use tracing::error; @@ -29,6 +31,18 @@ pub enum Error { source: ToStrError, }, + #[error("string parse: {source}")] + IpParse { + #[from] + source: IpNetworkError, + }, + + #[error("mac parse: {source}")] + MacParse { + #[from] + source: MacParseError, + }, + #[error("io: {source}")] Io { #[from] @@ -57,6 +71,14 @@ impl IntoResponse for Error { error!("{source}"); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") } + Self::MacParse { source } => { + error!("{source}"); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + } + Self::IpParse { source } => { + error!("{source}"); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + } }; let body = Json(json!({ "error": error_message, diff --git a/src/main.rs b/src/main.rs index 4ef129b..7d8c1da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,42 +1,44 @@ -use std::env; -use std::sync::Arc; -use axum::{Router, routing::post}; -use axum::routing::{get, put}; -use dashmap::DashMap; -use sqlx::PgPool; -use time::util::local_offset; -use tokio::sync::broadcast::{channel, Sender}; -use tracing::{info, level_filters::LevelFilter}; -use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; use crate::config::Config; use crate::db::init_db_pool; use crate::routes::device; use crate::routes::start::start; use crate::routes::status::status; -use crate::services::ping::{BroadcastCommands, StatusMap}; +use crate::services::ping::StatusMap; +use axum::routing::{get, put}; +use axum::{routing::post, Router}; +use dashmap::DashMap; +use services::ping::BroadcastCommand; +use sqlx::PgPool; +use tracing_subscriber::fmt::time::UtcTime; +use std::env; +use std::sync::Arc; +use tokio::sync::broadcast::{channel, Sender}; +use tracing::{info, level_filters::LevelFilter}; +use tracing_subscriber::{ + fmt, + prelude::*, + EnvFilter, +}; mod auth; mod config; -mod routes; -mod wol; mod db; mod error; +mod routes; mod services; +mod wol; #[tokio::main] async fn main() -> color_eyre::eyre::Result<()> { - color_eyre::install()?; + - unsafe { local_offset::set_soundness(local_offset::Soundness::Unsound); } let time_format = time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); - let loc = LocalTime::new(time_format); + let loc = UtcTime::new(time_format); tracing_subscriber::registry() - .with(fmt::layer() - .with_timer(loc) - ) + .with(fmt::layer().with_timer(loc)) .with( EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) @@ -56,8 +58,13 @@ async fn main() -> color_eyre::eyre::Result<()> { let (tx, _) = channel(32); let ping_map: StatusMap = DashMap::new(); - - let shared_state = Arc::new(AppState { db, config: config.clone(), ping_send: tx, ping_map }); + + let shared_state = Arc::new(AppState { + db, + config: config.clone(), + ping_send: tx, + ping_map, + }); let app = Router::new() .route("/start", post(start)) @@ -69,8 +76,7 @@ async fn main() -> color_eyre::eyre::Result<()> { let addr = config.serveraddr; info!("start server on {}", addr); - let listener = tokio::net::TcpListener::bind(addr) - .await?; + let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app).await?; Ok(()) @@ -79,6 +85,6 @@ async fn main() -> color_eyre::eyre::Result<()> { pub struct AppState { db: PgPool, config: Config, - ping_send: Sender, + ping_send: Sender, ping_map: StatusMap, } diff --git a/src/routes.rs b/src/routes.rs new file mode 100644 index 0000000..d5ab0d6 --- /dev/null +++ b/src/routes.rs @@ -0,0 +1,3 @@ +pub mod start; +pub mod device; +pub mod status; \ No newline at end of file diff --git a/src/routes/device.rs b/src/routes/device.rs index 5ca574a..2f0093d 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs @@ -4,9 +4,11 @@ use crate::error::Error; use axum::extract::State; use axum::http::HeaderMap; use axum::Json; +use mac_address::MacAddress; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use std::sync::Arc; +use sqlx::types::ipnetwork::IpNetwork; +use std::{sync::Arc, str::FromStr}; use tracing::{debug, info}; pub async fn get( @@ -14,7 +16,7 @@ pub async fn get( headers: HeaderMap, Json(payload): Json, ) -> Result, Error> { - info!("add device {}", payload.id); + info!("get device {}", payload.id); let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { @@ -52,18 +54,21 @@ pub async fn put( "add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip ); + let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { + let ip = IpNetwork::from_str(&payload.ip)?; + let mac = MacAddress::from_str(&payload.mac)?; sqlx::query!( r#" INSERT INTO devices (id, mac, broadcast_addr, ip) VALUES ($1, $2, $3, $4); "#, payload.id, - payload.mac, + mac, payload.broadcast_addr, - payload.ip + ip ) .execute(&state.db) .await?; @@ -99,6 +104,8 @@ pub async fn post( let secret = headers.get("authorization"); let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); if authorized { + let ip = IpNetwork::from_str(&payload.ip)?; + let mac = MacAddress::from_str(&payload.mac)?; let device = sqlx::query_as!( Device, r#" @@ -106,9 +113,9 @@ pub async fn post( SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 RETURNING id, mac, broadcast_addr, ip, times; "#, - payload.mac, + mac, payload.broadcast_addr, - payload.ip, + ip, payload.id ) .fetch_one(&state.db) diff --git a/src/routes/mod.rs b/src/routes/mod.rs deleted file mode 100644 index d5ab0d6..0000000 --- a/src/routes/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod start; -pub mod device; -pub mod status; \ No newline at end of file diff --git a/src/routes/start.rs b/src/routes/start.rs index ec4f98f..4888325 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -12,7 +12,6 @@ use std::sync::Arc; use tracing::{debug, info}; use uuid::Uuid; -#[axum_macros::debug_handler] pub async fn start( State(state): State>, headers: HeaderMap, @@ -41,45 +40,11 @@ pub async fn start( let _ = send_packet( bind_addr, &device.broadcast_addr, - &create_buffer(&device.mac)?, + &create_buffer(&device.mac.to_string())?, )?; let dev_id = device.id.clone(); let uuid = if payload.ping.is_some_and(|ping| ping) { - let mut uuid: Option = None; - for (key, value) in state.ping_map.clone() { - if value.ip == device.ip { - debug!("service already exists"); - uuid = Some(key); - break; - } - } - let uuid_gen = match uuid { - Some(u) => u, - None => Uuid::new_v4().to_string(), - }; - let uuid_genc = uuid_gen.clone(); - - tokio::spawn(async move { - debug!("init ping service"); - state.ping_map.insert( - uuid_gen.clone(), - PingValue { - ip: device.ip.clone(), - online: false, - }, - ); - - crate::services::ping::spawn( - state.ping_send.clone(), - &state.config, - device, - uuid_gen.clone(), - &state.ping_map, - &state.db, - ) - .await; - }); - Some(uuid_genc) + Some(setup_ping(state, device)) } else { None }; @@ -93,6 +58,45 @@ pub async fn start( } } +fn setup_ping(state: Arc, device: Device) -> String { + let mut uuid: Option = None; + for (key, value) in state.ping_map.clone() { + if value.ip == device.ip { + debug!("service already exists"); + uuid = Some(key); + break; + } + } + let uuid_gen = match uuid { + Some(u) => u, + None => Uuid::new_v4().to_string(), + }; + let uuid_ret = uuid_gen.clone(); + + debug!("init ping service"); + state.ping_map.insert( + uuid_gen.clone(), + PingValue { + ip: device.ip, + online: false, + }, + ); + + tokio::spawn(async move { + crate::services::ping::spawn( + state.ping_send.clone(), + &state.config, + device, + uuid_gen, + &state.ping_map, + &state.db, + ) + .await; + }); + + uuid_ret +} + #[derive(Deserialize)] pub struct Payload { id: String, diff --git a/src/routes/status.rs b/src/routes/status.rs index 31ef996..0e25f7d 100644 --- a/src/routes/status.rs +++ b/src/routes/status.rs @@ -1,10 +1,79 @@ -use std::sync::Arc; +use crate::services::ping::BroadcastCommand; +use crate::AppState; +use axum::extract::ws::{Message, WebSocket}; use axum::extract::{State, WebSocketUpgrade}; use axum::response::Response; -use crate::AppState; -use crate::services::ping::status_websocket; +use sqlx::PgPool; +use std::sync::Arc; +use tracing::{debug, trace}; -#[axum_macros::debug_handler] pub async fn status(State(state): State>, ws: WebSocketUpgrade) -> Response { - ws.on_upgrade(move |socket| status_websocket(socket, state)) + ws.on_upgrade(move |socket| websocket(socket, state)) +} + +pub async fn websocket(mut socket: WebSocket, state: Arc) { + trace!("wait for ws message (uuid)"); + let msg = socket.recv().await; + let uuid = msg.unwrap().unwrap().into_text().unwrap(); + + trace!("Search for uuid: {}", uuid); + + let eta = get_eta(&state.db).await; + let _ = socket + .send(Message::Text(format!("eta_{eta}_{uuid}"))) + .await; + + let device_exists = state.ping_map.contains_key(&uuid); + if device_exists { + let _ = socket + .send(receive_ping_broadcast(state.clone(), uuid).await) + .await; + } else { + debug!("didn't find any device"); + let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await; + }; + + let _ = socket.close().await; +} + +async fn receive_ping_broadcast(state: Arc, uuid: String) -> Message { + let pm = state.ping_map.clone().into_read_only(); + let device = pm.get(&uuid).expect("fatal error"); + debug!("got device: {} (online: {})", device.ip, device.online); + if device.online { + debug!("already started"); + Message::Text(BroadcastCommand::success(uuid).to_string()) + } else { + loop { + trace!("wait for tx message"); + let message = state + .ping_send + .subscribe() + .recv() + .await + .expect("fatal error"); + trace!("got message {:?}", message); + + if message.uuid != uuid { + continue; + } + trace!("message == uuid success"); + return Message::Text(message.to_string()); + } + } +} + +async fn get_eta(db: &PgPool) -> i64 { + let query = sqlx::query!(r#"SELECT times FROM devices;"#) + .fetch_one(db) + .await + .unwrap(); + + let times = if let Some(times) = query.times { + times + } else { + vec![0] + }; + + times.iter().sum::() / i64::try_from(times.len()).unwrap() } diff --git a/src/services.rs b/src/services.rs new file mode 100644 index 0000000..a766209 --- /dev/null +++ b/src/services.rs @@ -0,0 +1 @@ +pub mod ping; diff --git a/src/services/mod.rs b/src/services/mod.rs deleted file mode 100644 index a766209..0000000 --- a/src/services/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod ping; diff --git a/src/services/ping.rs b/src/services/ping.rs index 9b164c8..9191f86 100644 --- a/src/services/ping.rs +++ b/src/services/ping.rs @@ -1,59 +1,58 @@ -use std::str::FromStr; -use std::net::IpAddr; -use std::sync::Arc; - -use axum::extract::ws::WebSocket; -use axum::extract::ws::Message; +use crate::config::Config; +use crate::db::Device; use dashmap::DashMap; +use ipnetwork::IpNetwork; use sqlx::PgPool; +use std::fmt::Display; use time::{Duration, Instant}; use tokio::sync::broadcast::Sender; use tracing::{debug, error, trace}; -use crate::AppState; -use crate::config::Config; -use crate::db::Device; pub type StatusMap = DashMap; #[derive(Debug, Clone)] pub struct Value { - pub ip: String, - pub online: bool + pub ip: IpNetwork, + pub online: bool, } -pub async fn spawn(tx: Sender, config: &Config, device: Device, uuid: String, ping_map: &StatusMap, db: &PgPool) { +pub async fn spawn( + tx: Sender, + config: &Config, + device: Device, + uuid: String, + ping_map: &StatusMap, + db: &PgPool, +) { let timer = Instant::now(); let payload = [0; 8]; - let ping_ip = IpAddr::from_str(&device.ip).expect("bad ip"); - - let mut msg: Option = None; + let mut msg: Option = None; while msg.is_none() { - let ping = surge_ping::ping( - ping_ip, - &payload - ).await; + let ping = surge_ping::ping(device.ip.ip(), &payload).await; if let Err(ping) = ping { let ping_timeout = matches!(ping, surge_ping::SurgeError::Timeout { .. }); if !ping_timeout { error!("{}", ping.to_string()); - msg = Some(BroadcastCommands::Error(uuid.clone())); + msg = Some(BroadcastCommand::error(uuid.clone())); } if timer.elapsed() >= Duration::minutes(config.pingtimeout) { - msg = Some(BroadcastCommands::Timeout(uuid.clone())); + msg = Some(BroadcastCommand::timeout(uuid.clone())); } } else { - let (_, duration) = ping.map_err(|err| error!("{}", err.to_string())).expect("fatal error"); + let (_, duration) = ping + .map_err(|err| error!("{}", err.to_string())) + .expect("fatal error"); debug!("ping took {:?}", duration); - msg = Some(BroadcastCommands::Success(uuid.clone())); + msg = Some(BroadcastCommand::success(uuid.clone())); }; } let msg = msg.expect("fatal error"); let _ = tx.send(msg.clone()); - if let BroadcastCommands::Success(..) = msg { + if let BroadcastCommands::Success = msg.command { sqlx::query!( r#" UPDATE devices @@ -62,8 +61,17 @@ pub async fn spawn(tx: Sender, config: &Config, device: Devic "#, timer.elapsed().whole_seconds(), device.id - ).execute(db).await.unwrap(); - ping_map.insert(uuid.clone(), Value { ip: device.ip.clone(), online: true }); + ) + .execute(db) + .await + .unwrap(); + ping_map.insert( + uuid.clone(), + Value { + ip: device.ip, + online: true, + }, + ); tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; } trace!("remove {} from ping_map", uuid); @@ -72,74 +80,48 @@ pub async fn spawn(tx: Sender, config: &Config, device: Devic #[derive(Clone, Debug, PartialEq)] pub enum BroadcastCommands { - Success(String), - Timeout(String), - Error(String), + Success, + Timeout, + Error, } -pub async fn status_websocket(mut socket: WebSocket, state: Arc) { - trace!("wait for ws message (uuid)"); - let msg = socket.recv().await; - let uuid = msg.unwrap().unwrap().into_text().unwrap(); - - trace!("Search for uuid: {}", uuid); - - let eta = get_eta(&state.db).await; - let _ = socket.send(Message::Text(format!("eta_{eta}_{uuid}"))).await; +#[derive(Clone, Debug, PartialEq)] +pub struct BroadcastCommand { + pub uuid: String, + pub command: BroadcastCommands, +} - let device_exists = state.ping_map.contains_key(&uuid); - if device_exists { - let _ = socket.send(process_device(state.clone(), uuid).await).await; - } else { - debug!("didn't find any device"); - let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await; - }; +impl Display for BroadcastCommand { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let prefix = match self.command { + BroadcastCommands::Success => "start", + BroadcastCommands::Timeout => "timeout", + BroadcastCommands::Error => "error", + }; - let _ = socket.close().await; + f.write_str(format!("{prefix}_{}", self.uuid).as_str()) + } } -async fn get_eta(db: &PgPool) -> i64 { - let query = sqlx::query!( - r#"SELECT times FROM devices;"# - ).fetch_one(db).await.unwrap(); - - let times = match query.times { - None => { vec![0] }, - Some(t) => t, - }; - times.iter().sum::() / i64::try_from(times.len()).unwrap() +impl BroadcastCommand { + pub fn success(uuid: String) -> Self { + Self { + uuid, + command: BroadcastCommands::Success, + } + } -} + pub fn timeout(uuid: String) -> Self { + Self { + uuid, + command: BroadcastCommands::Timeout, + } + } -async fn process_device(state: Arc, uuid: String) -> Message { - let pm = state.ping_map.clone().into_read_only(); - let device = pm.get(&uuid).expect("fatal error"); - debug!("got device: {} (online: {})", device.ip, device.online); - if device.online { - debug!("already started"); - Message::Text(format!("start_{uuid}")) - } else { - loop { - trace!("wait for tx message"); - let message = state.ping_send.subscribe().recv().await.expect("fatal error"); - trace!("got message {:?}", message); - return match message { - BroadcastCommands::Success(msg_uuid) => { - if msg_uuid != uuid { continue; } - trace!("message == uuid success"); - Message::Text(format!("start_{uuid}")) - }, - BroadcastCommands::Timeout(msg_uuid) => { - if msg_uuid != uuid { continue; } - trace!("message == uuid timeout"); - Message::Text(format!("timeout_{uuid}")) - }, - BroadcastCommands::Error(msg_uuid) => { - if msg_uuid != uuid { continue; } - trace!("message == uuid error"); - Message::Text(format!("error_{uuid}")) - } - } + pub fn error(uuid: String) -> Self { + Self { + uuid, + command: BroadcastCommands::Error, } } } -- cgit v1.2.3