diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/auth.rs | 21 | ||||
-rw-r--r-- | src/error.rs | 37 | ||||
-rw-r--r-- | src/main.rs | 19 | ||||
-rw-r--r-- | src/routes/device.rs | 8 | ||||
-rw-r--r-- | src/routes/start.rs | 21 | ||||
-rw-r--r-- | src/services/mod.rs | 1 | ||||
-rw-r--r-- | src/services/ping.rs | 53 | ||||
-rw-r--r-- | src/wol.rs | 17 |
8 files changed, 141 insertions, 36 deletions
diff --git a/src/auth.rs b/src/auth.rs index 0fffa60..e4b1c2f 100644 --- a/src/auth.rs +++ b/src/auth.rs | |||
@@ -1,8 +1,8 @@ | |||
1 | use std::error::Error; | ||
2 | use axum::headers::HeaderValue; | 1 | use axum::headers::HeaderValue; |
3 | use axum::http::StatusCode; | 2 | use axum::http::StatusCode; |
3 | use axum::http::header::ToStrError; | ||
4 | use tracing::{debug, error, trace}; | 4 | use tracing::{debug, error, trace}; |
5 | use crate::auth::AuthError::{MissingSecret, ServerError, WrongSecret}; | 5 | use crate::auth::AuthError::{MissingSecret, WrongSecret}; |
6 | use crate::config::SETTINGS; | 6 | use crate::config::SETTINGS; |
7 | 7 | ||
8 | pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { | 8 | pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { |
@@ -11,8 +11,8 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { | |||
11 | trace!("value exists"); | 11 | trace!("value exists"); |
12 | let key = SETTINGS | 12 | let key = SETTINGS |
13 | .get_string("apikey") | 13 | .get_string("apikey") |
14 | .map_err(|err| ServerError(Box::new(err)))?; | 14 | .map_err(AuthError::Config)?; |
15 | if value.to_str().map_err(|err| ServerError(Box::new(err)))? == key.as_str() { | 15 | if value.to_str().map_err(AuthError::HeaderToStr)? == key.as_str() { |
16 | debug!("successful auth"); | 16 | debug!("successful auth"); |
17 | Ok(true) | 17 | Ok(true) |
18 | } else { | 18 | } else { |
@@ -29,15 +29,20 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { | |||
29 | pub enum AuthError { | 29 | pub enum AuthError { |
30 | WrongSecret, | 30 | WrongSecret, |
31 | MissingSecret, | 31 | MissingSecret, |
32 | ServerError(Box<dyn Error>), | 32 | Config(config::ConfigError), |
33 | HeaderToStr(ToStrError) | ||
33 | } | 34 | } |
34 | 35 | ||
35 | impl AuthError { | 36 | impl AuthError { |
36 | pub fn get(self) -> (StatusCode, &'static str) { | 37 | pub fn get(self) -> (StatusCode, &'static str) { |
37 | match self { | 38 | match self { |
38 | AuthError::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), | 39 | Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), |
39 | AuthError::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), | 40 | Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), |
40 | AuthError::ServerError(err) => { | 41 | Self::Config(err) => { |
42 | error!("server error: {}", err.to_string()); | ||
43 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
44 | }, | ||
45 | Self::HeaderToStr(err) => { | ||
41 | error!("server error: {}", err.to_string()); | 46 | error!("server error: {}", err.to_string()); |
42 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | 47 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") |
43 | }, | 48 | }, |
diff --git a/src/error.rs b/src/error.rs index db2fc86..f143ee9 100644 --- a/src/error.rs +++ b/src/error.rs | |||
@@ -1,4 +1,5 @@ | |||
1 | use std::error::Error; | 1 | use std::error::Error; |
2 | use std::io; | ||
2 | use axum::http::StatusCode; | 3 | use axum::http::StatusCode; |
3 | use axum::Json; | 4 | use axum::Json; |
4 | use axum::response::{IntoResponse, Response}; | 5 | use axum::response::{IntoResponse, Response}; |
@@ -8,21 +9,45 @@ use crate::auth::AuthError; | |||
8 | 9 | ||
9 | #[derive(Debug)] | 10 | #[derive(Debug)] |
10 | pub enum WebolError { | 11 | pub enum WebolError { |
11 | Auth(AuthError), | ||
12 | Generic, | 12 | Generic, |
13 | Server(Box<dyn Error>), | 13 | Auth(AuthError), |
14 | Ping(surge_ping::SurgeError), | ||
15 | DB(sqlx::Error), | ||
16 | IpParse(<std::net::IpAddr as std::str::FromStr>::Err), | ||
17 | BufferParse(std::num::ParseIntError), | ||
18 | Broadcast(io::Error), | ||
19 | Axum(axum::Error) | ||
14 | } | 20 | } |
15 | 21 | ||
16 | impl IntoResponse for WebolError { | 22 | impl IntoResponse for WebolError { |
17 | fn into_response(self) -> Response { | 23 | fn into_response(self) -> Response { |
18 | let (status, error_message) = match self { | 24 | let (status, error_message) = match self { |
19 | WebolError::Auth(err) => err.get(), | 25 | Self::Auth(err) => err.get(), |
20 | WebolError::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), | 26 | Self::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), |
21 | WebolError::Server(err) => { | 27 | Self::Ping(err) => { |
28 | error!("Ping: {}", err.source().unwrap()); | ||
29 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
30 | }, | ||
31 | Self::IpParse(err) => { | ||
32 | error!("server error: {}", err.to_string()); | ||
33 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
34 | }, | ||
35 | Self::DB(err) => { | ||
36 | error!("server error: {}", err.to_string()); | ||
37 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
38 | }, | ||
39 | Self::Broadcast(err) => { | ||
40 | error!("server error: {}", err.to_string()); | ||
41 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
42 | }, | ||
43 | Self::BufferParse(err) => { | ||
44 | error!("server error: {}", err.to_string()); | ||
45 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
46 | }, | ||
47 | Self::Axum(err) => { | ||
22 | error!("server error: {}", err.to_string()); | 48 | error!("server error: {}", err.to_string()); |
23 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | 49 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") |
24 | }, | 50 | }, |
25 | |||
26 | }; | 51 | }; |
27 | let body = Json(json!({ | 52 | let body = Json(json!({ |
28 | "error": error_message, | 53 | "error": error_message, |
diff --git a/src/main.rs b/src/main.rs index ce12cf6..9c31ec8 100644 --- a/src/main.rs +++ b/src/main.rs | |||
@@ -4,12 +4,14 @@ use axum::{Router, routing::post}; | |||
4 | use axum::routing::{get, put}; | 4 | use axum::routing::{get, put}; |
5 | use sqlx::PgPool; | 5 | use sqlx::PgPool; |
6 | use time::util::local_offset; | 6 | use time::util::local_offset; |
7 | use tokio::sync::mpsc::{self, Sender}; | ||
7 | use tracing::{info, level_filters::LevelFilter}; | 8 | use tracing::{info, level_filters::LevelFilter}; |
8 | use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; | 9 | use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; |
9 | use crate::config::SETTINGS; | 10 | use crate::config::SETTINGS; |
10 | use crate::db::init_db_pool; | 11 | use crate::db::init_db_pool; |
11 | use crate::routes::device::{get_device, post_device, put_device}; | 12 | use crate::routes::device::{get_device, post_device, put_device}; |
12 | use crate::routes::start::start; | 13 | use crate::routes::start::start; |
14 | use crate::services::ping::ws_ping; | ||
13 | 15 | ||
14 | mod auth; | 16 | mod auth; |
15 | mod config; | 17 | mod config; |
@@ -17,6 +19,7 @@ mod routes; | |||
17 | mod wol; | 19 | mod wol; |
18 | mod db; | 20 | mod db; |
19 | mod error; | 21 | mod error; |
22 | mod services; | ||
20 | 23 | ||
21 | #[tokio::main] | 24 | #[tokio::main] |
22 | async fn main() { | 25 | async fn main() { |
@@ -43,13 +46,23 @@ async fn main() { | |||
43 | let db = init_db_pool().await; | 46 | let db = init_db_pool().await; |
44 | sqlx::migrate!().run(&db).await.unwrap(); | 47 | sqlx::migrate!().run(&db).await.unwrap(); |
45 | 48 | ||
46 | let shared_state = Arc::new(AppState { db }); | 49 | let (tx, mut rx) = mpsc::channel(32); |
50 | |||
51 | // FIXME: once_cell? or just static mutable | ||
52 | tokio::spawn( async move { | ||
53 | while let Some(message) = rx.recv().await { | ||
54 | println!("GOT = {}", message); | ||
55 | } | ||
56 | }); | ||
57 | |||
58 | let shared_state = Arc::new(AppState { db, ping_send: tx }); | ||
47 | 59 | ||
48 | let app = Router::new() | 60 | let app = Router::new() |
49 | .route("/start", post(start)) | 61 | .route("/start", post(start)) |
50 | .route("/device", get(get_device)) | 62 | .route("/device", get(get_device)) |
51 | .route("/device", put(put_device)) | 63 | .route("/device", put(put_device)) |
52 | .route("/device", post(post_device)) | 64 | .route("/device", post(post_device)) |
65 | .route("/status", get(ws_ping)) | ||
53 | .with_state(shared_state); | 66 | .with_state(shared_state); |
54 | 67 | ||
55 | let addr = SETTINGS.get_string("serveraddr").unwrap_or("0.0.0.0:7229".to_string()); | 68 | let addr = SETTINGS.get_string("serveraddr").unwrap_or("0.0.0.0:7229".to_string()); |
@@ -61,5 +74,7 @@ async fn main() { | |||
61 | } | 74 | } |
62 | 75 | ||
63 | pub struct AppState { | 76 | pub struct AppState { |
64 | db: PgPool | 77 | db: PgPool, |
78 | ping_send: Sender<String>, | ||
79 | // ping_receive: Receiver<String> | ||
65 | } | 80 | } |
diff --git a/src/routes/device.rs b/src/routes/device.rs index 025c7d0..248d1e0 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs | |||
@@ -21,7 +21,7 @@ pub async fn get_device(State(state): State<Arc<crate::AppState>>, headers: Head | |||
21 | WHERE id = $1; | 21 | WHERE id = $1; |
22 | "#, | 22 | "#, |
23 | payload.id | 23 | payload.id |
24 | ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; | 24 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; |
25 | 25 | ||
26 | Ok(Json(json!(device))) | 26 | Ok(Json(json!(device))) |
27 | } else { | 27 | } else { |
@@ -46,7 +46,7 @@ pub async fn put_device(State(state): State<Arc<crate::AppState>>, headers: Head | |||
46 | payload.id, | 46 | payload.id, |
47 | payload.mac, | 47 | payload.mac, |
48 | payload.broadcast_addr | 48 | payload.broadcast_addr |
49 | ).execute(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; | 49 | ).execute(&state.db).await.map_err(WebolError::DB)?; |
50 | 50 | ||
51 | Ok(Json(json!(PutDeviceResponse { success: true }))) | 51 | Ok(Json(json!(PutDeviceResponse { success: true }))) |
52 | } else { | 52 | } else { |
@@ -80,7 +80,7 @@ pub async fn post_device(State(state): State<Arc<crate::AppState>>, headers: Hea | |||
80 | payload.mac, | 80 | payload.mac, |
81 | payload.broadcast_addr, | 81 | payload.broadcast_addr, |
82 | payload.id | 82 | payload.id |
83 | ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; | 83 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; |
84 | 84 | ||
85 | Ok(Json(json!(device))) | 85 | Ok(Json(json!(device))) |
86 | } else { | 86 | } else { |
@@ -93,4 +93,4 @@ pub struct PostDevicePayload { | |||
93 | id: String, | 93 | id: String, |
94 | mac: String, | 94 | mac: String, |
95 | broadcast_addr: String, | 95 | broadcast_addr: String, |
96 | } \ No newline at end of file | 96 | } |
diff --git a/src/routes/start.rs b/src/routes/start.rs index 163d58c..b45fe5b 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs | |||
@@ -14,7 +14,8 @@ use crate::error::WebolError; | |||
14 | pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<StartPayload>) -> Result<Json<Value>, WebolError> { | 14 | pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<StartPayload>) -> Result<Json<Value>, WebolError> { |
15 | info!("POST request"); | 15 | info!("POST request"); |
16 | let secret = headers.get("authorization"); | 16 | let secret = headers.get("authorization"); |
17 | if auth(secret).map_err(WebolError::Auth)? { | 17 | let authorized = auth(secret).map_err(WebolError::Auth)?; |
18 | if authorized { | ||
18 | let device = sqlx::query_as!( | 19 | let device = sqlx::query_as!( |
19 | Device, | 20 | Device, |
20 | r#" | 21 | r#" |
@@ -23,7 +24,7 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap | |||
23 | WHERE id = $1; | 24 | WHERE id = $1; |
24 | "#, | 25 | "#, |
25 | payload.id | 26 | payload.id |
26 | ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; | 27 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; |
27 | 28 | ||
28 | info!("starting {}", device.id); | 29 | info!("starting {}", device.id); |
29 | 30 | ||
@@ -32,10 +33,14 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap | |||
32 | .unwrap_or("0.0.0.0:1111".to_string()); | 33 | .unwrap_or("0.0.0.0:1111".to_string()); |
33 | 34 | ||
34 | let _ = send_packet( | 35 | let _ = send_packet( |
35 | &bind_addr.parse().map_err(|err| WebolError::Server(Box::new(err)))?, | 36 | &bind_addr.parse().map_err(WebolError::IpParse)?, |
36 | &device.broadcast_addr.parse().map_err(|err| WebolError::Server(Box::new(err)))?, | 37 | &device.broadcast_addr.parse().map_err(WebolError::IpParse)?, |
37 | create_buffer(&device.mac).map_err(|err| WebolError::Server(Box::new(err)))? | 38 | create_buffer(&device.mac)? |
38 | ).map_err(|err| WebolError::Server(Box::new(err))); | 39 | )?; |
40 | |||
41 | if payload.ping.is_some_and(|ping| ping) { | ||
42 | tokio::spawn(async move {crate::services::ping::spawn(state.ping_send.clone()).await}); | ||
43 | } | ||
39 | Ok(Json(json!(StartResponse { id: device.id, boot: true }))) | 44 | Ok(Json(json!(StartResponse { id: device.id, boot: true }))) |
40 | } else { | 45 | } else { |
41 | Err(WebolError::Generic) | 46 | Err(WebolError::Generic) |
@@ -45,11 +50,11 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap | |||
45 | #[derive(Deserialize)] | 50 | #[derive(Deserialize)] |
46 | pub struct StartPayload { | 51 | pub struct StartPayload { |
47 | id: String, | 52 | id: String, |
48 | _test: Option<bool>, | 53 | ping: Option<bool>, |
49 | } | 54 | } |
50 | 55 | ||
51 | #[derive(Serialize)] | 56 | #[derive(Serialize)] |
52 | struct StartResponse { | 57 | struct StartResponse { |
53 | id: String, | 58 | id: String, |
54 | boot: bool, | 59 | boot: bool, |
55 | } \ No newline at end of file | 60 | } |
diff --git a/src/services/mod.rs b/src/services/mod.rs new file mode 100644 index 0000000..a766209 --- /dev/null +++ b/src/services/mod.rs | |||
@@ -0,0 +1 @@ | |||
pub mod ping; | |||
diff --git a/src/services/ping.rs b/src/services/ping.rs new file mode 100644 index 0000000..6e710ec --- /dev/null +++ b/src/services/ping.rs | |||
@@ -0,0 +1,53 @@ | |||
1 | use std::sync::Arc; | ||
2 | |||
3 | use axum::{extract::{WebSocketUpgrade, ws::WebSocket, State}, response::Response}; | ||
4 | use tokio::sync::mpsc::Sender; | ||
5 | use tracing::{debug, error}; | ||
6 | |||
7 | use crate::{error::WebolError, AppState}; | ||
8 | |||
9 | pub async fn spawn(tx: Sender<String>) -> Result<(), WebolError> { | ||
10 | let payload = [0; 8]; | ||
11 | |||
12 | let mut cont = true; | ||
13 | while cont { | ||
14 | let ping = surge_ping::ping( | ||
15 | "192.168.178.28".parse().map_err(WebolError::IpParse)?, | ||
16 | &payload | ||
17 | ).await; | ||
18 | |||
19 | if let Err(ping) = ping { | ||
20 | cont = matches!(ping, surge_ping::SurgeError::Timeout { .. }); | ||
21 | |||
22 | debug!("{}", cont); | ||
23 | |||
24 | if !cont { | ||
25 | return Err(ping).map_err(WebolError::Ping) | ||
26 | } | ||
27 | |||
28 | } else { | ||
29 | let (_, duration) = ping.unwrap(); | ||
30 | debug!("Ping took {:?}", duration); | ||
31 | cont = false; | ||
32 | // FIXME: remove unwrap | ||
33 | tx.send("Got ping".to_string()).await.unwrap(); | ||
34 | }; | ||
35 | } | ||
36 | |||
37 | Ok(()) | ||
38 | } | ||
39 | |||
40 | pub async fn ws_ping(ws: WebSocketUpgrade, State(_state): State<Arc<AppState>>) -> Response { | ||
41 | ws.on_upgrade(handle_socket) | ||
42 | } | ||
43 | |||
44 | // FIXME: Handle commands through enum | ||
45 | async fn handle_socket(mut socket: WebSocket) { | ||
46 | // TODO: Understand Cow | ||
47 | |||
48 | // match socket.send(axum::extract::ws::Message::Close(Some(CloseFrame { code: 4000, reason: Cow::Owned("started".to_owned()) }))).await.map_err(WebolError::Axum) { | ||
49 | match socket.send(axum::extract::ws::Message::Text("started".to_string())).await.map_err(WebolError::Axum) { | ||
50 | Ok(..) => (), | ||
51 | Err(err) => { error!("Server Error: {:?}", err) } | ||
52 | }; | ||
53 | } | ||
@@ -1,16 +1,17 @@ | |||
1 | use std::net::{SocketAddr, UdpSocket}; | 1 | use std::net::{SocketAddr, UdpSocket}; |
2 | use std::num::ParseIntError; | 2 | |
3 | use crate::error::WebolError; | ||
3 | 4 | ||
4 | /// Creates the magic packet from a mac address | 5 | /// Creates the magic packet from a mac address |
5 | /// | 6 | /// |
6 | /// # Panics | 7 | /// # Panics |
7 | /// | 8 | /// |
8 | /// Panics if `mac_addr` is an invalid mac | 9 | /// Panics if `mac_addr` is an invalid mac |
9 | pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, ParseIntError> { | 10 | pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, WebolError> { |
10 | let mut mac = Vec::new(); | 11 | let mut mac = Vec::new(); |
11 | let sp = mac_addr.split(':'); | 12 | let sp = mac_addr.split(':'); |
12 | for f in sp { | 13 | for f in sp { |
13 | mac.push(u8::from_str_radix(f, 16)?); | 14 | mac.push(u8::from_str_radix(f, 16).map_err(WebolError::BufferParse)?) |
14 | }; | 15 | }; |
15 | let mut buf = vec![255; 6]; | 16 | let mut buf = vec![255; 6]; |
16 | for _ in 0..16 { | 17 | for _ in 0..16 { |
@@ -22,8 +23,8 @@ pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, ParseIntError> { | |||
22 | } | 23 | } |
23 | 24 | ||
24 | /// Sends a buffer on UDP broadcast | 25 | /// Sends a buffer on UDP broadcast |
25 | pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: Vec<u8>) -> Result<usize, std::io::Error> { | 26 | pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: Vec<u8>) -> Result<usize, WebolError> { |
26 | let socket = UdpSocket::bind(bind_addr)?; | 27 | let socket = UdpSocket::bind(bind_addr).map_err(WebolError::Broadcast)?; |
27 | socket.set_broadcast(true)?; | 28 | socket.set_broadcast(true).map_err(WebolError::Broadcast)?; |
28 | socket.send_to(&buffer, broadcast_addr) | 29 | socket.send_to(&buffer, broadcast_addr).map_err(WebolError::Broadcast) |
29 | } \ No newline at end of file | 30 | } |