From dcfb83fb2069bfcf4642b03453253e35479bf3da Mon Sep 17 00:00:00 2001 From: fx Date: Tue, 24 Oct 2023 01:15:22 +0200 Subject: first ping impl baseline, doesnt work --- src/auth.rs | 21 +++++++++++++-------- src/error.rs | 37 ++++++++++++++++++++++++++++++------ src/main.rs | 19 +++++++++++++++++-- src/routes/device.rs | 8 ++++---- src/routes/start.rs | 21 +++++++++++++-------- src/services/mod.rs | 1 + src/services/ping.rs | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/wol.rs | 17 +++++++++-------- 8 files changed, 141 insertions(+), 36 deletions(-) create mode 100644 src/services/mod.rs create mode 100644 src/services/ping.rs (limited to 'src') diff --git a/src/auth.rs b/src/auth.rs index 0fffa60..e4b1c2f 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,8 +1,8 @@ -use std::error::Error; use axum::headers::HeaderValue; use axum::http::StatusCode; +use axum::http::header::ToStrError; use tracing::{debug, error, trace}; -use crate::auth::AuthError::{MissingSecret, ServerError, WrongSecret}; +use crate::auth::AuthError::{MissingSecret, WrongSecret}; use crate::config::SETTINGS; pub fn auth(secret: Option<&HeaderValue>) -> Result { @@ -11,8 +11,8 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result { trace!("value exists"); let key = SETTINGS .get_string("apikey") - .map_err(|err| ServerError(Box::new(err)))?; - if value.to_str().map_err(|err| ServerError(Box::new(err)))? == key.as_str() { + .map_err(AuthError::Config)?; + if value.to_str().map_err(AuthError::HeaderToStr)? == key.as_str() { debug!("successful auth"); Ok(true) } else { @@ -29,15 +29,20 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result { pub enum AuthError { WrongSecret, MissingSecret, - ServerError(Box), + Config(config::ConfigError), + HeaderToStr(ToStrError) } impl AuthError { pub fn get(self) -> (StatusCode, &'static str) { match self { - AuthError::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), - AuthError::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), - AuthError::ServerError(err) => { + Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), + Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), + Self::Config(err) => { + error!("server error: {}", err.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + }, + Self::HeaderToStr(err) => { error!("server error: {}", err.to_string()); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") }, diff --git a/src/error.rs b/src/error.rs index db2fc86..f143ee9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,5 @@ use std::error::Error; +use std::io; use axum::http::StatusCode; use axum::Json; use axum::response::{IntoResponse, Response}; @@ -8,21 +9,45 @@ use crate::auth::AuthError; #[derive(Debug)] pub enum WebolError { - Auth(AuthError), Generic, - Server(Box), + Auth(AuthError), + Ping(surge_ping::SurgeError), + DB(sqlx::Error), + IpParse(::Err), + BufferParse(std::num::ParseIntError), + Broadcast(io::Error), + Axum(axum::Error) } impl IntoResponse for WebolError { fn into_response(self) -> Response { let (status, error_message) = match self { - WebolError::Auth(err) => err.get(), - WebolError::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), - WebolError::Server(err) => { + Self::Auth(err) => err.get(), + Self::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), + Self::Ping(err) => { + error!("Ping: {}", err.source().unwrap()); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + }, + Self::IpParse(err) => { + error!("server error: {}", err.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + }, + Self::DB(err) => { + error!("server error: {}", err.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + }, + Self::Broadcast(err) => { + error!("server error: {}", err.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + }, + Self::BufferParse(err) => { + error!("server error: {}", err.to_string()); + (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") + }, + Self::Axum(err) => { error!("server error: {}", err.to_string()); (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") }, - }; let body = Json(json!({ "error": error_message, diff --git a/src/main.rs b/src/main.rs index ce12cf6..9c31ec8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,12 +4,14 @@ use axum::{Router, routing::post}; use axum::routing::{get, put}; use sqlx::PgPool; use time::util::local_offset; +use tokio::sync::mpsc::{self, Sender}; use tracing::{info, level_filters::LevelFilter}; use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; use crate::config::SETTINGS; use crate::db::init_db_pool; use crate::routes::device::{get_device, post_device, put_device}; use crate::routes::start::start; +use crate::services::ping::ws_ping; mod auth; mod config; @@ -17,6 +19,7 @@ mod routes; mod wol; mod db; mod error; +mod services; #[tokio::main] async fn main() { @@ -43,13 +46,23 @@ async fn main() { let db = init_db_pool().await; sqlx::migrate!().run(&db).await.unwrap(); - let shared_state = Arc::new(AppState { db }); + let (tx, mut rx) = mpsc::channel(32); + + // FIXME: once_cell? or just static mutable + tokio::spawn( async move { + while let Some(message) = rx.recv().await { + println!("GOT = {}", message); + } + }); + + let shared_state = Arc::new(AppState { db, ping_send: tx }); let app = Router::new() .route("/start", post(start)) .route("/device", get(get_device)) .route("/device", put(put_device)) .route("/device", post(post_device)) + .route("/status", get(ws_ping)) .with_state(shared_state); let addr = SETTINGS.get_string("serveraddr").unwrap_or("0.0.0.0:7229".to_string()); @@ -61,5 +74,7 @@ async fn main() { } pub struct AppState { - db: PgPool + db: PgPool, + ping_send: Sender, + // ping_receive: Receiver } diff --git a/src/routes/device.rs b/src/routes/device.rs index 025c7d0..248d1e0 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs @@ -21,7 +21,7 @@ pub async fn get_device(State(state): State>, headers: Head WHERE id = $1; "#, payload.id - ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; + ).fetch_one(&state.db).await.map_err(WebolError::DB)?; Ok(Json(json!(device))) } else { @@ -46,7 +46,7 @@ pub async fn put_device(State(state): State>, headers: Head payload.id, payload.mac, payload.broadcast_addr - ).execute(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; + ).execute(&state.db).await.map_err(WebolError::DB)?; Ok(Json(json!(PutDeviceResponse { success: true }))) } else { @@ -80,7 +80,7 @@ pub async fn post_device(State(state): State>, headers: Hea payload.mac, payload.broadcast_addr, payload.id - ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; + ).fetch_one(&state.db).await.map_err(WebolError::DB)?; Ok(Json(json!(device))) } else { @@ -93,4 +93,4 @@ pub struct PostDevicePayload { id: String, mac: String, broadcast_addr: String, -} \ No newline at end of file +} diff --git a/src/routes/start.rs b/src/routes/start.rs index 163d58c..b45fe5b 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs @@ -14,7 +14,8 @@ use crate::error::WebolError; pub async fn start(State(state): State>, headers: HeaderMap, Json(payload): Json) -> Result, WebolError> { info!("POST request"); let secret = headers.get("authorization"); - if auth(secret).map_err(WebolError::Auth)? { + let authorized = auth(secret).map_err(WebolError::Auth)?; + if authorized { let device = sqlx::query_as!( Device, r#" @@ -23,7 +24,7 @@ pub async fn start(State(state): State>, headers: HeaderMap WHERE id = $1; "#, payload.id - ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; + ).fetch_one(&state.db).await.map_err(WebolError::DB)?; info!("starting {}", device.id); @@ -32,10 +33,14 @@ pub async fn start(State(state): State>, headers: HeaderMap .unwrap_or("0.0.0.0:1111".to_string()); let _ = send_packet( - &bind_addr.parse().map_err(|err| WebolError::Server(Box::new(err)))?, - &device.broadcast_addr.parse().map_err(|err| WebolError::Server(Box::new(err)))?, - create_buffer(&device.mac).map_err(|err| WebolError::Server(Box::new(err)))? - ).map_err(|err| WebolError::Server(Box::new(err))); + &bind_addr.parse().map_err(WebolError::IpParse)?, + &device.broadcast_addr.parse().map_err(WebolError::IpParse)?, + create_buffer(&device.mac)? + )?; + + if payload.ping.is_some_and(|ping| ping) { + tokio::spawn(async move {crate::services::ping::spawn(state.ping_send.clone()).await}); + } Ok(Json(json!(StartResponse { id: device.id, boot: true }))) } else { Err(WebolError::Generic) @@ -45,11 +50,11 @@ pub async fn start(State(state): State>, headers: HeaderMap #[derive(Deserialize)] pub struct StartPayload { id: String, - _test: Option, + ping: Option, } #[derive(Serialize)] struct StartResponse { id: String, boot: bool, -} \ No newline at end of file +} diff --git a/src/services/mod.rs b/src/services/mod.rs new file mode 100644 index 0000000..a766209 --- /dev/null +++ b/src/services/mod.rs @@ -0,0 +1 @@ +pub mod ping; diff --git a/src/services/ping.rs b/src/services/ping.rs new file mode 100644 index 0000000..6e710ec --- /dev/null +++ b/src/services/ping.rs @@ -0,0 +1,53 @@ +use std::sync::Arc; + +use axum::{extract::{WebSocketUpgrade, ws::WebSocket, State}, response::Response}; +use tokio::sync::mpsc::Sender; +use tracing::{debug, error}; + +use crate::{error::WebolError, AppState}; + +pub async fn spawn(tx: Sender) -> Result<(), WebolError> { + let payload = [0; 8]; + + let mut cont = true; + while cont { + let ping = surge_ping::ping( + "192.168.178.28".parse().map_err(WebolError::IpParse)?, + &payload + ).await; + + if let Err(ping) = ping { + cont = matches!(ping, surge_ping::SurgeError::Timeout { .. }); + + debug!("{}", cont); + + if !cont { + return Err(ping).map_err(WebolError::Ping) + } + + } else { + let (_, duration) = ping.unwrap(); + debug!("Ping took {:?}", duration); + cont = false; + // FIXME: remove unwrap + tx.send("Got ping".to_string()).await.unwrap(); + }; + } + + Ok(()) +} + +pub async fn ws_ping(ws: WebSocketUpgrade, State(_state): State>) -> Response { + ws.on_upgrade(handle_socket) +} + +// FIXME: Handle commands through enum +async fn handle_socket(mut socket: WebSocket) { + // TODO: Understand Cow + + // match socket.send(axum::extract::ws::Message::Close(Some(CloseFrame { code: 4000, reason: Cow::Owned("started".to_owned()) }))).await.map_err(WebolError::Axum) { + match socket.send(axum::extract::ws::Message::Text("started".to_string())).await.map_err(WebolError::Axum) { + Ok(..) => (), + Err(err) => { error!("Server Error: {:?}", err) } + }; +} diff --git a/src/wol.rs b/src/wol.rs index 80b66cd..0cdcae3 100644 --- a/src/wol.rs +++ b/src/wol.rs @@ -1,16 +1,17 @@ use std::net::{SocketAddr, UdpSocket}; -use std::num::ParseIntError; + +use crate::error::WebolError; /// Creates the magic packet from a mac address /// /// # Panics /// /// Panics if `mac_addr` is an invalid mac -pub fn create_buffer(mac_addr: &str) -> Result, ParseIntError> { +pub fn create_buffer(mac_addr: &str) -> Result, WebolError> { let mut mac = Vec::new(); let sp = mac_addr.split(':'); for f in sp { - mac.push(u8::from_str_radix(f, 16)?); + mac.push(u8::from_str_radix(f, 16).map_err(WebolError::BufferParse)?) }; let mut buf = vec![255; 6]; for _ in 0..16 { @@ -22,8 +23,8 @@ pub fn create_buffer(mac_addr: &str) -> Result, ParseIntError> { } /// Sends a buffer on UDP broadcast -pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: Vec) -> Result { - let socket = UdpSocket::bind(bind_addr)?; - socket.set_broadcast(true)?; - socket.send_to(&buffer, broadcast_addr) -} \ No newline at end of file +pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: Vec) -> Result { + let socket = UdpSocket::bind(bind_addr).map_err(WebolError::Broadcast)?; + socket.set_broadcast(true).map_err(WebolError::Broadcast)?; + socket.send_to(&buffer, broadcast_addr).map_err(WebolError::Broadcast) +} -- cgit v1.2.3