diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/auth.rs | 21 | ||||
-rw-r--r-- | src/db.rs | 5 | ||||
-rw-r--r-- | src/error.rs | 32 | ||||
-rw-r--r-- | src/main.rs | 18 | ||||
-rw-r--r-- | src/routes/device.rs | 34 | ||||
-rw-r--r-- | src/routes/mod.rs | 3 | ||||
-rw-r--r-- | src/routes/start.rs | 39 | ||||
-rw-r--r-- | src/routes/status.rs | 10 | ||||
-rw-r--r-- | src/services/mod.rs | 1 | ||||
-rw-r--r-- | src/services/ping.rs | 118 | ||||
-rw-r--r-- | src/wol.rs | 17 |
11 files changed, 243 insertions, 55 deletions
diff --git a/src/auth.rs b/src/auth.rs index 0fffa60..e4b1c2f 100644 --- a/src/auth.rs +++ b/src/auth.rs | |||
@@ -1,8 +1,8 @@ | |||
1 | use std::error::Error; | ||
2 | use axum::headers::HeaderValue; | 1 | use axum::headers::HeaderValue; |
3 | use axum::http::StatusCode; | 2 | use axum::http::StatusCode; |
3 | use axum::http::header::ToStrError; | ||
4 | use tracing::{debug, error, trace}; | 4 | use tracing::{debug, error, trace}; |
5 | use crate::auth::AuthError::{MissingSecret, ServerError, WrongSecret}; | 5 | use crate::auth::AuthError::{MissingSecret, WrongSecret}; |
6 | use crate::config::SETTINGS; | 6 | use crate::config::SETTINGS; |
7 | 7 | ||
8 | pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { | 8 | pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { |
@@ -11,8 +11,8 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { | |||
11 | trace!("value exists"); | 11 | trace!("value exists"); |
12 | let key = SETTINGS | 12 | let key = SETTINGS |
13 | .get_string("apikey") | 13 | .get_string("apikey") |
14 | .map_err(|err| ServerError(Box::new(err)))?; | 14 | .map_err(AuthError::Config)?; |
15 | if value.to_str().map_err(|err| ServerError(Box::new(err)))? == key.as_str() { | 15 | if value.to_str().map_err(AuthError::HeaderToStr)? == key.as_str() { |
16 | debug!("successful auth"); | 16 | debug!("successful auth"); |
17 | Ok(true) | 17 | Ok(true) |
18 | } else { | 18 | } else { |
@@ -29,15 +29,20 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { | |||
29 | pub enum AuthError { | 29 | pub enum AuthError { |
30 | WrongSecret, | 30 | WrongSecret, |
31 | MissingSecret, | 31 | MissingSecret, |
32 | ServerError(Box<dyn Error>), | 32 | Config(config::ConfigError), |
33 | HeaderToStr(ToStrError) | ||
33 | } | 34 | } |
34 | 35 | ||
35 | impl AuthError { | 36 | impl AuthError { |
36 | pub fn get(self) -> (StatusCode, &'static str) { | 37 | pub fn get(self) -> (StatusCode, &'static str) { |
37 | match self { | 38 | match self { |
38 | AuthError::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), | 39 | Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), |
39 | AuthError::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), | 40 | Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), |
40 | AuthError::ServerError(err) => { | 41 | Self::Config(err) => { |
42 | error!("server error: {}", err.to_string()); | ||
43 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
44 | }, | ||
45 | Self::HeaderToStr(err) => { | ||
41 | error!("server error: {}", err.to_string()); | 46 | error!("server error: {}", err.to_string()); |
42 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | 47 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") |
43 | }, | 48 | }, |
@@ -8,11 +8,12 @@ use tracing::{debug, info}; | |||
8 | #[cfg(not(debug_assertions))] | 8 | #[cfg(not(debug_assertions))] |
9 | use crate::config::SETTINGS; | 9 | use crate::config::SETTINGS; |
10 | 10 | ||
11 | #[derive(Serialize)] | 11 | #[derive(Serialize, Debug)] |
12 | pub struct Device { | 12 | pub struct Device { |
13 | pub id: String, | 13 | pub id: String, |
14 | pub mac: String, | 14 | pub mac: String, |
15 | pub broadcast_addr: String | 15 | pub broadcast_addr: String, |
16 | pub ip: String | ||
16 | } | 17 | } |
17 | 18 | ||
18 | pub async fn init_db_pool() -> PgPool { | 19 | pub async fn init_db_pool() -> PgPool { |
diff --git a/src/error.rs b/src/error.rs index db2fc86..5b82534 100644 --- a/src/error.rs +++ b/src/error.rs | |||
@@ -1,4 +1,4 @@ | |||
1 | use std::error::Error; | 1 | use std::io; |
2 | use axum::http::StatusCode; | 2 | use axum::http::StatusCode; |
3 | use axum::Json; | 3 | use axum::Json; |
4 | use axum::response::{IntoResponse, Response}; | 4 | use axum::response::{IntoResponse, Response}; |
@@ -8,25 +8,41 @@ use crate::auth::AuthError; | |||
8 | 8 | ||
9 | #[derive(Debug)] | 9 | #[derive(Debug)] |
10 | pub enum WebolError { | 10 | pub enum WebolError { |
11 | Auth(AuthError), | ||
12 | Generic, | 11 | Generic, |
13 | Server(Box<dyn Error>), | 12 | Auth(AuthError), |
13 | DB(sqlx::Error), | ||
14 | IpParse(<std::net::IpAddr as std::str::FromStr>::Err), | ||
15 | BufferParse(std::num::ParseIntError), | ||
16 | Broadcast(io::Error), | ||
14 | } | 17 | } |
15 | 18 | ||
16 | impl IntoResponse for WebolError { | 19 | impl IntoResponse for WebolError { |
17 | fn into_response(self) -> Response { | 20 | fn into_response(self) -> Response { |
18 | let (status, error_message) = match self { | 21 | let (status, error_message) = match self { |
19 | WebolError::Auth(err) => err.get(), | 22 | Self::Auth(err) => { |
20 | WebolError::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), | 23 | err.get() |
21 | WebolError::Server(err) => { | 24 | }, |
25 | Self::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), | ||
26 | Self::IpParse(err) => { | ||
27 | error!("server error: {}", err.to_string()); | ||
28 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
29 | }, | ||
30 | Self::DB(err) => { | ||
31 | error!("server error: {}", err.to_string()); | ||
32 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
33 | }, | ||
34 | Self::Broadcast(err) => { | ||
35 | error!("server error: {}", err.to_string()); | ||
36 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | ||
37 | }, | ||
38 | Self::BufferParse(err) => { | ||
22 | error!("server error: {}", err.to_string()); | 39 | error!("server error: {}", err.to_string()); |
23 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") | 40 | (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") |
24 | }, | 41 | }, |
25 | |||
26 | }; | 42 | }; |
27 | let body = Json(json!({ | 43 | let body = Json(json!({ |
28 | "error": error_message, | 44 | "error": error_message, |
29 | })); | 45 | })); |
30 | (status, body).into_response() | 46 | (status, body).into_response() |
31 | } | 47 | } |
32 | } | 48 | } \ No newline at end of file |
diff --git a/src/main.rs b/src/main.rs index ce12cf6..e96b736 100644 --- a/src/main.rs +++ b/src/main.rs | |||
@@ -2,14 +2,18 @@ use std::env; | |||
2 | use std::sync::Arc; | 2 | use std::sync::Arc; |
3 | use axum::{Router, routing::post}; | 3 | use axum::{Router, routing::post}; |
4 | use axum::routing::{get, put}; | 4 | use axum::routing::{get, put}; |
5 | use dashmap::DashMap; | ||
5 | use sqlx::PgPool; | 6 | use sqlx::PgPool; |
6 | use time::util::local_offset; | 7 | use time::util::local_offset; |
8 | use tokio::sync::broadcast::{channel, Sender}; | ||
7 | use tracing::{info, level_filters::LevelFilter}; | 9 | use tracing::{info, level_filters::LevelFilter}; |
8 | use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; | 10 | use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; |
9 | use crate::config::SETTINGS; | 11 | use crate::config::SETTINGS; |
10 | use crate::db::init_db_pool; | 12 | use crate::db::init_db_pool; |
11 | use crate::routes::device::{get_device, post_device, put_device}; | 13 | use crate::routes::device::{get_device, post_device, put_device}; |
12 | use crate::routes::start::start; | 14 | use crate::routes::start::start; |
15 | use crate::routes::status::status; | ||
16 | use crate::services::ping::{BroadcastCommands, PingMap}; | ||
13 | 17 | ||
14 | mod auth; | 18 | mod auth; |
15 | mod config; | 19 | mod config; |
@@ -17,6 +21,7 @@ mod routes; | |||
17 | mod wol; | 21 | mod wol; |
18 | mod db; | 22 | mod db; |
19 | mod error; | 23 | mod error; |
24 | mod services; | ||
20 | 25 | ||
21 | #[tokio::main] | 26 | #[tokio::main] |
22 | async fn main() { | 27 | async fn main() { |
@@ -43,13 +48,18 @@ async fn main() { | |||
43 | let db = init_db_pool().await; | 48 | let db = init_db_pool().await; |
44 | sqlx::migrate!().run(&db).await.unwrap(); | 49 | sqlx::migrate!().run(&db).await.unwrap(); |
45 | 50 | ||
46 | let shared_state = Arc::new(AppState { db }); | 51 | let (tx, _) = channel(32); |
52 | |||
53 | let ping_map: PingMap = DashMap::new(); | ||
54 | |||
55 | let shared_state = Arc::new(AppState { db, ping_send: tx, ping_map }); | ||
47 | 56 | ||
48 | let app = Router::new() | 57 | let app = Router::new() |
49 | .route("/start", post(start)) | 58 | .route("/start", post(start)) |
50 | .route("/device", get(get_device)) | 59 | .route("/device", get(get_device)) |
51 | .route("/device", put(put_device)) | 60 | .route("/device", put(put_device)) |
52 | .route("/device", post(post_device)) | 61 | .route("/device", post(post_device)) |
62 | .route("/status", get(status)) | ||
53 | .with_state(shared_state); | 63 | .with_state(shared_state); |
54 | 64 | ||
55 | let addr = SETTINGS.get_string("serveraddr").unwrap_or("0.0.0.0:7229".to_string()); | 65 | let addr = SETTINGS.get_string("serveraddr").unwrap_or("0.0.0.0:7229".to_string()); |
@@ -61,5 +71,7 @@ async fn main() { | |||
61 | } | 71 | } |
62 | 72 | ||
63 | pub struct AppState { | 73 | pub struct AppState { |
64 | db: PgPool | 74 | db: PgPool, |
65 | } | 75 | ping_send: Sender<BroadcastCommands>, |
76 | ping_map: PingMap, | ||
77 | } \ No newline at end of file | ||
diff --git a/src/routes/device.rs b/src/routes/device.rs index 025c7d0..1eeff0b 100644 --- a/src/routes/device.rs +++ b/src/routes/device.rs | |||
@@ -4,24 +4,26 @@ use axum::headers::HeaderMap; | |||
4 | use axum::Json; | 4 | use axum::Json; |
5 | use serde::{Deserialize, Serialize}; | 5 | use serde::{Deserialize, Serialize}; |
6 | use serde_json::{json, Value}; | 6 | use serde_json::{json, Value}; |
7 | use tracing::info; | 7 | use tracing::{debug, info}; |
8 | use crate::auth::auth; | 8 | use crate::auth::auth; |
9 | use crate::db::Device; | 9 | use crate::db::Device; |
10 | use crate::error::WebolError; | 10 | use crate::error::WebolError; |
11 | 11 | ||
12 | pub async fn get_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<GetDevicePayload>) -> Result<Json<Value>, WebolError> { | 12 | pub async fn get_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<GetDevicePayload>) -> Result<Json<Value>, WebolError> { |
13 | info!("GET request"); | 13 | info!("add device {}", payload.id); |
14 | let secret = headers.get("authorization"); | 14 | let secret = headers.get("authorization"); |
15 | if auth(secret).map_err(WebolError::Auth)? { | 15 | if auth(secret).map_err(WebolError::Auth)? { |
16 | let device = sqlx::query_as!( | 16 | let device = sqlx::query_as!( |
17 | Device, | 17 | Device, |
18 | r#" | 18 | r#" |
19 | SELECT id, mac, broadcast_addr | 19 | SELECT id, mac, broadcast_addr, ip |
20 | FROM devices | 20 | FROM devices |
21 | WHERE id = $1; | 21 | WHERE id = $1; |
22 | "#, | 22 | "#, |
23 | payload.id | 23 | payload.id |
24 | ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; | 24 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; |
25 | |||
26 | debug!("got device {:?}", device); | ||
25 | 27 | ||
26 | Ok(Json(json!(device))) | 28 | Ok(Json(json!(device))) |
27 | } else { | 29 | } else { |
@@ -35,18 +37,19 @@ pub struct GetDevicePayload { | |||
35 | } | 37 | } |
36 | 38 | ||
37 | pub async fn put_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PutDevicePayload>) -> Result<Json<Value>, WebolError> { | 39 | pub async fn put_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PutDevicePayload>) -> Result<Json<Value>, WebolError> { |
38 | info!("PUT request"); | 40 | info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); |
39 | let secret = headers.get("authorization"); | 41 | let secret = headers.get("authorization"); |
40 | if auth(secret).map_err(WebolError::Auth)? { | 42 | if auth(secret).map_err(WebolError::Auth)? { |
41 | sqlx::query!( | 43 | sqlx::query!( |
42 | r#" | 44 | r#" |
43 | INSERT INTO devices (id, mac, broadcast_addr) | 45 | INSERT INTO devices (id, mac, broadcast_addr, ip) |
44 | VALUES ($1, $2, $3); | 46 | VALUES ($1, $2, $3, $4); |
45 | "#, | 47 | "#, |
46 | payload.id, | 48 | payload.id, |
47 | payload.mac, | 49 | payload.mac, |
48 | payload.broadcast_addr | 50 | payload.broadcast_addr, |
49 | ).execute(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; | 51 | payload.ip |
52 | ).execute(&state.db).await.map_err(WebolError::DB)?; | ||
50 | 53 | ||
51 | Ok(Json(json!(PutDeviceResponse { success: true }))) | 54 | Ok(Json(json!(PutDeviceResponse { success: true }))) |
52 | } else { | 55 | } else { |
@@ -59,6 +62,7 @@ pub struct PutDevicePayload { | |||
59 | id: String, | 62 | id: String, |
60 | mac: String, | 63 | mac: String, |
61 | broadcast_addr: String, | 64 | broadcast_addr: String, |
65 | ip: String | ||
62 | } | 66 | } |
63 | 67 | ||
64 | #[derive(Serialize)] | 68 | #[derive(Serialize)] |
@@ -67,20 +71,21 @@ pub struct PutDeviceResponse { | |||
67 | } | 71 | } |
68 | 72 | ||
69 | pub async fn post_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PostDevicePayload>) -> Result<Json<Value>, WebolError> { | 73 | pub async fn post_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PostDevicePayload>) -> Result<Json<Value>, WebolError> { |
70 | info!("POST request"); | 74 | info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); |
71 | let secret = headers.get("authorization"); | 75 | let secret = headers.get("authorization"); |
72 | if auth(secret).map_err(WebolError::Auth)? { | 76 | if auth(secret).map_err(WebolError::Auth)? { |
73 | let device = sqlx::query_as!( | 77 | let device = sqlx::query_as!( |
74 | Device, | 78 | Device, |
75 | r#" | 79 | r#" |
76 | UPDATE devices | 80 | UPDATE devices |
77 | SET mac = $1, broadcast_addr = $2 WHERE id = $3 | 81 | SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 |
78 | RETURNING id, mac, broadcast_addr; | 82 | RETURNING id, mac, broadcast_addr, ip; |
79 | "#, | 83 | "#, |
80 | payload.mac, | 84 | payload.mac, |
81 | payload.broadcast_addr, | 85 | payload.broadcast_addr, |
86 | payload.ip, | ||
82 | payload.id | 87 | payload.id |
83 | ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; | 88 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; |
84 | 89 | ||
85 | Ok(Json(json!(device))) | 90 | Ok(Json(json!(device))) |
86 | } else { | 91 | } else { |
@@ -93,4 +98,5 @@ pub struct PostDevicePayload { | |||
93 | id: String, | 98 | id: String, |
94 | mac: String, | 99 | mac: String, |
95 | broadcast_addr: String, | 100 | broadcast_addr: String, |
96 | } \ No newline at end of file | 101 | ip: String, |
102 | } | ||
diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 12fbfab..d5ab0d6 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs | |||
@@ -1,2 +1,3 @@ | |||
1 | pub mod start; | 1 | pub mod start; |
2 | pub mod device; \ No newline at end of file | 2 | pub mod device; |
3 | pub mod status; \ No newline at end of file | ||
diff --git a/src/routes/start.rs b/src/routes/start.rs index 163d58c..271f924 100644 --- a/src/routes/start.rs +++ b/src/routes/start.rs | |||
@@ -4,26 +4,30 @@ use serde::{Deserialize, Serialize}; | |||
4 | use std::sync::Arc; | 4 | use std::sync::Arc; |
5 | use axum::extract::State; | 5 | use axum::extract::State; |
6 | use serde_json::{json, Value}; | 6 | use serde_json::{json, Value}; |
7 | use tracing::info; | 7 | use tracing::{debug, info}; |
8 | use uuid::Uuid; | ||
8 | use crate::auth::auth; | 9 | use crate::auth::auth; |
9 | use crate::config::SETTINGS; | 10 | use crate::config::SETTINGS; |
10 | use crate::wol::{create_buffer, send_packet}; | 11 | use crate::wol::{create_buffer, send_packet}; |
11 | use crate::db::Device; | 12 | use crate::db::Device; |
12 | use crate::error::WebolError; | 13 | use crate::error::WebolError; |
14 | use crate::services::ping::PingValue; | ||
13 | 15 | ||
16 | #[axum_macros::debug_handler] | ||
14 | pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<StartPayload>) -> Result<Json<Value>, WebolError> { | 17 | pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<StartPayload>) -> Result<Json<Value>, WebolError> { |
15 | info!("POST request"); | 18 | info!("POST request"); |
16 | let secret = headers.get("authorization"); | 19 | let secret = headers.get("authorization"); |
17 | if auth(secret).map_err(WebolError::Auth)? { | 20 | let authorized = auth(secret).map_err(WebolError::Auth)?; |
21 | if authorized { | ||
18 | let device = sqlx::query_as!( | 22 | let device = sqlx::query_as!( |
19 | Device, | 23 | Device, |
20 | r#" | 24 | r#" |
21 | SELECT id, mac, broadcast_addr | 25 | SELECT id, mac, broadcast_addr, ip |
22 | FROM devices | 26 | FROM devices |
23 | WHERE id = $1; | 27 | WHERE id = $1; |
24 | "#, | 28 | "#, |
25 | payload.id | 29 | payload.id |
26 | ).fetch_one(&state.db).await.map_err(|err| WebolError::Server(Box::new(err)))?; | 30 | ).fetch_one(&state.db).await.map_err(WebolError::DB)?; |
27 | 31 | ||
28 | info!("starting {}", device.id); | 32 | info!("starting {}", device.id); |
29 | 33 | ||
@@ -32,11 +36,23 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap | |||
32 | .unwrap_or("0.0.0.0:1111".to_string()); | 36 | .unwrap_or("0.0.0.0:1111".to_string()); |
33 | 37 | ||
34 | let _ = send_packet( | 38 | let _ = send_packet( |
35 | &bind_addr.parse().map_err(|err| WebolError::Server(Box::new(err)))?, | 39 | &bind_addr.parse().map_err(WebolError::IpParse)?, |
36 | &device.broadcast_addr.parse().map_err(|err| WebolError::Server(Box::new(err)))?, | 40 | &device.broadcast_addr.parse().map_err(WebolError::IpParse)?, |
37 | create_buffer(&device.mac).map_err(|err| WebolError::Server(Box::new(err)))? | 41 | create_buffer(&device.mac)? |
38 | ).map_err(|err| WebolError::Server(Box::new(err))); | 42 | )?; |
39 | Ok(Json(json!(StartResponse { id: device.id, boot: true }))) | 43 | |
44 | let uuid = if payload.ping.is_some_and(|ping| ping) { | ||
45 | let uuid_gen = Uuid::new_v4().to_string(); | ||
46 | let uuid_genc = uuid_gen.clone(); | ||
47 | tokio::spawn(async move { | ||
48 | debug!("init ping service"); | ||
49 | state.ping_map.insert(uuid_gen.clone(), PingValue { ip: device.ip.clone(), online: false }); | ||
50 | |||
51 | crate::services::ping::spawn(state.ping_send.clone(), device.ip, uuid_gen.clone(), &state.ping_map).await | ||
52 | }); | ||
53 | Some(uuid_genc) | ||
54 | } else { None }; | ||
55 | Ok(Json(json!(StartResponse { id: device.id, boot: true, uuid }))) | ||
40 | } else { | 56 | } else { |
41 | Err(WebolError::Generic) | 57 | Err(WebolError::Generic) |
42 | } | 58 | } |
@@ -45,11 +61,12 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap | |||
45 | #[derive(Deserialize)] | 61 | #[derive(Deserialize)] |
46 | pub struct StartPayload { | 62 | pub struct StartPayload { |
47 | id: String, | 63 | id: String, |
48 | _test: Option<bool>, | 64 | ping: Option<bool>, |
49 | } | 65 | } |
50 | 66 | ||
51 | #[derive(Serialize)] | 67 | #[derive(Serialize)] |
52 | struct StartResponse { | 68 | struct StartResponse { |
53 | id: String, | 69 | id: String, |
54 | boot: bool, | 70 | boot: bool, |
55 | } \ No newline at end of file | 71 | uuid: Option<String>, |
72 | } | ||
diff --git a/src/routes/status.rs b/src/routes/status.rs new file mode 100644 index 0000000..45f3e51 --- /dev/null +++ b/src/routes/status.rs | |||
@@ -0,0 +1,10 @@ | |||
1 | use std::sync::Arc; | ||
2 | use axum::extract::{State, WebSocketUpgrade}; | ||
3 | use axum::response::Response; | ||
4 | use crate::AppState; | ||
5 | use crate::services::ping::status_websocket; | ||
6 | |||
7 | #[axum_macros::debug_handler] | ||
8 | pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response { | ||
9 | ws.on_upgrade(move |socket| status_websocket(socket, state)) | ||
10 | } \ No newline at end of file | ||
diff --git a/src/services/mod.rs b/src/services/mod.rs new file mode 100644 index 0000000..a766209 --- /dev/null +++ b/src/services/mod.rs | |||
@@ -0,0 +1 @@ | |||
pub mod ping; | |||
diff --git a/src/services/ping.rs b/src/services/ping.rs new file mode 100644 index 0000000..d900acb --- /dev/null +++ b/src/services/ping.rs | |||
@@ -0,0 +1,118 @@ | |||
1 | use std::sync::Arc; | ||
2 | |||
3 | use axum::extract::{ws::WebSocket}; | ||
4 | use axum::extract::ws::Message; | ||
5 | use dashmap::DashMap; | ||
6 | use time::{Duration, Instant}; | ||
7 | use tokio::sync::broadcast::{Sender}; | ||
8 | use tracing::{debug, error, trace}; | ||
9 | use crate::AppState; | ||
10 | use crate::config::SETTINGS; | ||
11 | |||
12 | pub type PingMap = DashMap<String, PingValue>; | ||
13 | |||
14 | #[derive(Debug, Clone)] | ||
15 | pub struct PingValue { | ||
16 | pub ip: String, | ||
17 | pub online: bool | ||
18 | } | ||
19 | |||
20 | pub async fn spawn(tx: Sender<BroadcastCommands>, ip: String, uuid: String, ping_map: &PingMap) { | ||
21 | let timer = Instant::now(); | ||
22 | let payload = [0; 8]; | ||
23 | |||
24 | let mut cont = true; | ||
25 | while cont { | ||
26 | let ping = surge_ping::ping( | ||
27 | ip.parse().expect("bad ip"), | ||
28 | &payload | ||
29 | ).await; | ||
30 | |||
31 | if let Err(ping) = ping { | ||
32 | cont = matches!(ping, surge_ping::SurgeError::Timeout { .. }); | ||
33 | if !cont { | ||
34 | error!("{}", ping.to_string()); | ||
35 | } | ||
36 | if timer.elapsed() >= Duration::minutes(SETTINGS.get_int("pingtimeout").unwrap_or(10)) { | ||
37 | let _ = tx.send(BroadcastCommands::PingTimeout(uuid.clone())); | ||
38 | trace!("remove {} from ping_map after timeout", uuid); | ||
39 | ping_map.remove(&uuid); | ||
40 | cont = false; | ||
41 | } | ||
42 | } else { | ||
43 | let (_, duration) = ping.map_err(|err| error!("{}", err.to_string())).expect("fatal error"); | ||
44 | debug!("ping took {:?}", duration); | ||
45 | cont = false; | ||
46 | handle_broadcast_send(&tx, ip.clone(), ping_map, uuid.clone()).await; | ||
47 | }; | ||
48 | } | ||
49 | } | ||
50 | |||
51 | async fn handle_broadcast_send(tx: &Sender<BroadcastCommands>, ip: String, ping_map: &PingMap, uuid: String) { | ||
52 | debug!("send pingsuccess message"); | ||
53 | let _ = tx.send(BroadcastCommands::PingSuccess(uuid.clone())); | ||
54 | trace!("sent message"); | ||
55 | ping_map.insert(uuid.clone(), PingValue { ip: ip.clone(), online: true }); | ||
56 | trace!("updated ping_map"); | ||
57 | tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; | ||
58 | debug!("remove {} from ping_map after success", uuid); | ||
59 | ping_map.remove(&uuid); | ||
60 | } | ||
61 | |||
62 | #[derive(Clone, Debug)] | ||
63 | pub enum BroadcastCommands { | ||
64 | PingSuccess(String), | ||
65 | PingTimeout(String) | ||
66 | } | ||
67 | |||
68 | pub async fn status_websocket(mut socket: WebSocket, state: Arc<AppState>) { | ||
69 | trace!("wait for ws message (uuid)"); | ||
70 | let msg = socket.recv().await; | ||
71 | let uuid = msg.unwrap().unwrap().into_text().unwrap(); | ||
72 | |||
73 | trace!("Search for uuid: {:?}", uuid); | ||
74 | |||
75 | let device_exists = state.ping_map.contains_key(&uuid); | ||
76 | match device_exists { | ||
77 | true => { | ||
78 | let _ = socket.send(process_device(state.clone(), uuid).await).await; | ||
79 | }, | ||
80 | false => { | ||
81 | debug!("didn't find any device"); | ||
82 | let _ = socket.send(Message::Text(format!("notfound_{}", uuid))).await; | ||
83 | }, | ||
84 | }; | ||
85 | |||
86 | let _ = socket.close().await; | ||
87 | } | ||
88 | |||
89 | async fn process_device(state: Arc<AppState>, uuid: String) -> Message { | ||
90 | let pm = state.ping_map.clone().into_read_only(); | ||
91 | let device = pm.get(&uuid).expect("fatal error"); | ||
92 | debug!("got device: {} (online: {})", device.ip, device.online); | ||
93 | match device.online { | ||
94 | true => { | ||
95 | debug!("already started"); | ||
96 | Message::Text(format!("start_{}", uuid)) | ||
97 | }, | ||
98 | false => { | ||
99 | loop{ | ||
100 | trace!("wait for tx message"); | ||
101 | let message = state.ping_send.subscribe().recv().await.expect("fatal error"); | ||
102 | trace!("got message {:?}", message); | ||
103 | return match message { | ||
104 | BroadcastCommands::PingSuccess(msg_uuid) => { | ||
105 | if msg_uuid != uuid { continue; } | ||
106 | trace!("message == uuid success"); | ||
107 | Message::Text(format!("start_{}", uuid)) | ||
108 | }, | ||
109 | BroadcastCommands::PingTimeout(msg_uuid) => { | ||
110 | if msg_uuid != uuid { continue; } | ||
111 | trace!("message == uuid timeout"); | ||
112 | Message::Text(format!("timeout_{}", uuid)) | ||
113 | } | ||
114 | } | ||
115 | } | ||
116 | } | ||
117 | } | ||
118 | } \ No newline at end of file | ||
@@ -1,16 +1,17 @@ | |||
1 | use std::net::{SocketAddr, UdpSocket}; | 1 | use std::net::{SocketAddr, UdpSocket}; |
2 | use std::num::ParseIntError; | 2 | |
3 | use crate::error::WebolError; | ||
3 | 4 | ||
4 | /// Creates the magic packet from a mac address | 5 | /// Creates the magic packet from a mac address |
5 | /// | 6 | /// |
6 | /// # Panics | 7 | /// # Panics |
7 | /// | 8 | /// |
8 | /// Panics if `mac_addr` is an invalid mac | 9 | /// Panics if `mac_addr` is an invalid mac |
9 | pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, ParseIntError> { | 10 | pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, WebolError> { |
10 | let mut mac = Vec::new(); | 11 | let mut mac = Vec::new(); |
11 | let sp = mac_addr.split(':'); | 12 | let sp = mac_addr.split(':'); |
12 | for f in sp { | 13 | for f in sp { |
13 | mac.push(u8::from_str_radix(f, 16)?); | 14 | mac.push(u8::from_str_radix(f, 16).map_err(WebolError::BufferParse)?) |
14 | }; | 15 | }; |
15 | let mut buf = vec![255; 6]; | 16 | let mut buf = vec![255; 6]; |
16 | for _ in 0..16 { | 17 | for _ in 0..16 { |
@@ -22,8 +23,8 @@ pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, ParseIntError> { | |||
22 | } | 23 | } |
23 | 24 | ||
24 | /// Sends a buffer on UDP broadcast | 25 | /// Sends a buffer on UDP broadcast |
25 | pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: Vec<u8>) -> Result<usize, std::io::Error> { | 26 | pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: Vec<u8>) -> Result<usize, WebolError> { |
26 | let socket = UdpSocket::bind(bind_addr)?; | 27 | let socket = UdpSocket::bind(bind_addr).map_err(WebolError::Broadcast)?; |
27 | socket.set_broadcast(true)?; | 28 | socket.set_broadcast(true).map_err(WebolError::Broadcast)?; |
28 | socket.send_to(&buffer, broadcast_addr) | 29 | socket.send_to(&buffer, broadcast_addr).map_err(WebolError::Broadcast) |
29 | } \ No newline at end of file | 30 | } |