aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/auth.rs43
-rw-r--r--src/db.rs6
-rw-r--r--src/error.rs88
-rw-r--r--src/extractors.rs24
-rw-r--r--src/main.rs60
-rw-r--r--src/routes.rs (renamed from src/routes/mod.rs)0
-rw-r--r--src/routes/device.rs137
-rw-r--r--src/routes/start.rs138
-rw-r--r--src/routes/status.rs79
-rw-r--r--src/services.rs (renamed from src/services/mod.rs)0
-rw-r--r--src/services/ping.rs154
-rw-r--r--src/wol.rs18
12 files changed, 416 insertions, 331 deletions
diff --git a/src/auth.rs b/src/auth.rs
deleted file mode 100644
index feca652..0000000
--- a/src/auth.rs
+++ /dev/null
@@ -1,43 +0,0 @@
1use axum::http::{StatusCode, HeaderValue};
2use axum::http::header::ToStrError;
3use tracing::{debug, error, trace};
4use crate::auth::Error::{MissingSecret, WrongSecret};
5use crate::config::Config;
6
7pub fn auth(config: &Config, secret: Option<&HeaderValue>) -> Result<bool, Error> {
8 debug!("auth request with secret {:?}", secret);
9 if let Some(value) = secret {
10 trace!("value exists");
11 let key = &config.apikey;
12 if value.to_str().map_err(Error::HeaderToStr)? == key.as_str() {
13 debug!("successful auth");
14 Ok(true)
15 } else {
16 debug!("unsuccessful auth (wrong secret)");
17 Err(WrongSecret)
18 }
19 } else {
20 debug!("unsuccessful auth (no secret)");
21 Err(MissingSecret)
22 }
23}
24
25#[derive(Debug)]
26pub enum Error {
27 WrongSecret,
28 MissingSecret,
29 HeaderToStr(ToStrError)
30}
31
32impl Error {
33 pub fn get(self) -> (StatusCode, &'static str) {
34 match self {
35 Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
36 Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"),
37 Self::HeaderToStr(err) => {
38 error!("server error: {}", err.to_string());
39 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
40 },
41 }
42 }
43}
diff --git a/src/db.rs b/src/db.rs
index 489a000..47e907d 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -1,13 +1,13 @@
1use serde::Serialize; 1use serde::Serialize;
2use sqlx::{PgPool, postgres::PgPoolOptions}; 2use sqlx::{PgPool, postgres::PgPoolOptions, types::{ipnetwork::IpNetwork, mac_address::MacAddress}};
3use tracing::{debug, info}; 3use tracing::{debug, info};
4 4
5#[derive(Serialize, Debug)] 5#[derive(Serialize, Debug)]
6pub struct Device { 6pub struct Device {
7 pub id: String, 7 pub id: String,
8 pub mac: String, 8 pub mac: MacAddress,
9 pub broadcast_addr: String, 9 pub broadcast_addr: String,
10 pub ip: String, 10 pub ip: IpNetwork,
11 pub times: Option<Vec<i64>> 11 pub times: Option<Vec<i64>>
12} 12}
13 13
diff --git a/src/error.rs b/src/error.rs
index 56d6c52..513b51b 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,44 +1,80 @@
1use std::io; 1use ::ipnetwork::IpNetworkError;
2use axum::http::header::ToStrError;
2use axum::http::StatusCode; 3use axum::http::StatusCode;
3use axum::Json;
4use axum::response::{IntoResponse, Response}; 4use axum::response::{IntoResponse, Response};
5use axum::Json;
6use mac_address::MacParseError;
5use serde_json::json; 7use serde_json::json;
8use std::io;
6use tracing::error; 9use tracing::error;
7use crate::auth::Error as AuthError;
8 10
9#[derive(Debug)] 11#[derive(Debug, thiserror::Error)]
10pub enum Error { 12pub enum Error {
11 Generic, 13 #[error("db: {source}")]
12 Auth(AuthError), 14 Db {
13 DB(sqlx::Error), 15 #[from]
14 IpParse(<std::net::IpAddr as std::str::FromStr>::Err), 16 source: sqlx::Error,
15 BufferParse(std::num::ParseIntError), 17 },
16 Broadcast(io::Error), 18
19 #[error("buffer parse: {source}")]
20 ParseInt {
21 #[from]
22 source: std::num::ParseIntError,
23 },
24
25 #[error("header parse: {source}")]
26 ParseHeader {
27 #[from]
28 source: ToStrError,
29 },
30
31 #[error("string parse: {source}")]
32 IpParse {
33 #[from]
34 source: IpNetworkError,
35 },
36
37 #[error("mac parse: {source}")]
38 MacParse {
39 #[from]
40 source: MacParseError,
41 },
42
43 #[error("io: {source}")]
44 Io {
45 #[from]
46 source: io::Error,
47 },
17} 48}
18 49
19impl IntoResponse for Error { 50impl IntoResponse for Error {
20 fn into_response(self) -> Response { 51 fn into_response(self) -> Response {
52 error!("{}", self.to_string());
21 let (status, error_message) = match self { 53 let (status, error_message) = match self {
22 Self::Auth(err) => { 54 Self::Db { source } => {
23 err.get() 55 error!("{source}");
24 }, 56 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
25 Self::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""), 57 }
26 Self::IpParse(err) => { 58 Self::Io { source } => {
27 error!("server error: {}", err.to_string()); 59 error!("{source}");
60 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
61 }
62 Self::ParseHeader { source } => {
63 error!("{source}");
28 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") 64 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
29 }, 65 }
30 Self::DB(err) => { 66 Self::ParseInt { source } => {
31 error!("server error: {}", err.to_string()); 67 error!("{source}");
32 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") 68 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
33 }, 69 }
34 Self::Broadcast(err) => { 70 Self::MacParse { source } => {
35 error!("server error: {}", err.to_string()); 71 error!("{source}");
36 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") 72 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
37 }, 73 }
38 Self::BufferParse(err) => { 74 Self::IpParse { source } => {
39 error!("server error: {}", err.to_string()); 75 error!("{source}");
40 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") 76 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
41 }, 77 }
42 }; 78 };
43 let body = Json(json!({ 79 let body = Json(json!({
44 "error": error_message, 80 "error": error_message,
diff --git a/src/extractors.rs b/src/extractors.rs
new file mode 100644
index 0000000..4d441e9
--- /dev/null
+++ b/src/extractors.rs
@@ -0,0 +1,24 @@
1use axum::{
2 extract::{Request, State},
3 http::{HeaderMap, StatusCode},
4 middleware::Next,
5 response::Response,
6};
7
8use crate::AppState;
9
10pub async fn auth(
11 State(state): State<AppState>,
12 headers: HeaderMap,
13 request: Request,
14 next: Next,
15) -> Result<Response, StatusCode> {
16 let secret = headers.get("authorization");
17 match secret {
18 Some(token) if token == state.config.apikey.as_str() => {
19 let response = next.run(request).await;
20 Ok(response)
21 }
22 _ => Err(StatusCode::UNAUTHORIZED),
23 }
24}
diff --git a/src/main.rs b/src/main.rs
index 4ef129b..d17984f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,42 +1,44 @@
1use std::env;
2use std::sync::Arc;
3use axum::{Router, routing::post};
4use axum::routing::{get, put};
5use dashmap::DashMap;
6use sqlx::PgPool;
7use time::util::local_offset;
8use tokio::sync::broadcast::{channel, Sender};
9use tracing::{info, level_filters::LevelFilter};
10use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*};
11use crate::config::Config; 1use crate::config::Config;
12use crate::db::init_db_pool; 2use crate::db::init_db_pool;
13use crate::routes::device; 3use crate::routes::device;
14use crate::routes::start::start; 4use crate::routes::start::start;
15use crate::routes::status::status; 5use crate::routes::status::status;
16use crate::services::ping::{BroadcastCommands, StatusMap}; 6use crate::services::ping::StatusMap;
7use axum::middleware::from_fn_with_state;
8use axum::routing::{get, put};
9use axum::{routing::post, Router};
10use dashmap::DashMap;
11use services::ping::BroadcastCommand;
12use sqlx::PgPool;
13use std::env;
14use std::sync::Arc;
15use tokio::sync::broadcast::{channel, Sender};
16use tracing::{info, level_filters::LevelFilter};
17use tracing_subscriber::fmt::time::UtcTime;
18use tracing_subscriber::{fmt, prelude::*, EnvFilter};
17 19
18mod auth;
19mod config; 20mod config;
20mod routes;
21mod wol;
22mod db; 21mod db;
23mod error; 22mod error;
23mod extractors;
24mod routes;
24mod services; 25mod services;
26mod wol;
25 27
26#[tokio::main] 28#[tokio::main]
27async fn main() -> color_eyre::eyre::Result<()> { 29async fn main() -> color_eyre::eyre::Result<()> {
28
29 color_eyre::install()?; 30 color_eyre::install()?;
30 31
31 unsafe { local_offset::set_soundness(local_offset::Soundness::Unsound); }
32 let time_format = 32 let time_format =
33 time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); 33 time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]");
34 let loc = LocalTime::new(time_format); 34 let loc = UtcTime::new(time_format);
35
36 let file_appender = tracing_appender::rolling::daily("logs", "webol.log");
37 let (non_blocking, _guard) = tracing_appender::non_blocking(file_appender);
35 38
36 tracing_subscriber::registry() 39 tracing_subscriber::registry()
37 .with(fmt::layer() 40 .with(fmt::layer().with_writer(non_blocking).with_ansi(false))
38 .with_timer(loc) 41 .with(fmt::layer().with_timer(loc))
39 )
40 .with( 42 .with(
41 EnvFilter::builder() 43 EnvFilter::builder()
42 .with_default_directive(LevelFilter::INFO.into()) 44 .with_default_directive(LevelFilter::INFO.into())
@@ -56,8 +58,13 @@ async fn main() -> color_eyre::eyre::Result<()> {
56 let (tx, _) = channel(32); 58 let (tx, _) = channel(32);
57 59
58 let ping_map: StatusMap = DashMap::new(); 60 let ping_map: StatusMap = DashMap::new();
59 61
60 let shared_state = Arc::new(AppState { db, config: config.clone(), ping_send: tx, ping_map }); 62 let shared_state = AppState {
63 db,
64 config: config.clone(),
65 ping_send: tx,
66 ping_map,
67 };
61 68
62 let app = Router::new() 69 let app = Router::new()
63 .route("/start", post(start)) 70 .route("/start", post(start))
@@ -65,20 +72,21 @@ async fn main() -> color_eyre::eyre::Result<()> {
65 .route("/device", put(device::put)) 72 .route("/device", put(device::put))
66 .route("/device", post(device::post)) 73 .route("/device", post(device::post))
67 .route("/status", get(status)) 74 .route("/status", get(status))
68 .with_state(shared_state); 75 .route_layer(from_fn_with_state(shared_state.clone(), extractors::auth))
76 .with_state(Arc::new(shared_state));
69 77
70 let addr = config.serveraddr; 78 let addr = config.serveraddr;
71 info!("start server on {}", addr); 79 info!("start server on {}", addr);
72 let listener = tokio::net::TcpListener::bind(addr) 80 let listener = tokio::net::TcpListener::bind(addr).await?;
73 .await?;
74 axum::serve(listener, app).await?; 81 axum::serve(listener, app).await?;
75 82
76 Ok(()) 83 Ok(())
77} 84}
78 85
86#[derive(Clone)]
79pub struct AppState { 87pub struct AppState {
80 db: PgPool, 88 db: PgPool,
81 config: Config, 89 config: Config,
82 ping_send: Sender<BroadcastCommands>, 90 ping_send: Sender<BroadcastCommand>,
83 ping_map: StatusMap, 91 ping_map: StatusMap,
84} 92}
diff --git a/src/routes/mod.rs b/src/routes.rs
index d5ab0d6..d5ab0d6 100644
--- a/src/routes/mod.rs
+++ b/src/routes.rs
diff --git a/src/routes/device.rs b/src/routes/device.rs
index c85df1b..d39d98e 100644
--- a/src/routes/device.rs
+++ b/src/routes/device.rs
@@ -1,34 +1,34 @@
1use std::sync::Arc; 1use crate::db::Device;
2use crate::error::Error;
2use axum::extract::State; 3use axum::extract::State;
3use axum::Json; 4use axum::Json;
4use axum::http::HeaderMap; 5use mac_address::MacAddress;
5use serde::{Deserialize, Serialize}; 6use serde::{Deserialize, Serialize};
6use serde_json::{json, Value}; 7use serde_json::{json, Value};
8use sqlx::types::ipnetwork::IpNetwork;
9use std::{sync::Arc, str::FromStr};
7use tracing::{debug, info}; 10use tracing::{debug, info};
8use crate::auth::auth;
9use crate::db::Device;
10use crate::error::Error;
11 11
12pub async fn get(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<GetDevicePayload>) -> Result<Json<Value>, Error> { 12pub async fn get(
13 info!("add device {}", payload.id); 13 State(state): State<Arc<crate::AppState>>,
14 let secret = headers.get("authorization"); 14 Json(payload): Json<GetDevicePayload>,
15 if auth(&state.config, secret).map_err(Error::Auth)? { 15) -> Result<Json<Value>, Error> {
16 let device = sqlx::query_as!( 16 info!("get device {}", payload.id);
17 Device, 17 let device = sqlx::query_as!(
18 r#" 18 Device,
19 SELECT id, mac, broadcast_addr, ip, times 19 r#"
20 FROM devices 20 SELECT id, mac, broadcast_addr, ip, times
21 WHERE id = $1; 21 FROM devices
22 "#, 22 WHERE id = $1;
23 payload.id 23 "#,
24 ).fetch_one(&state.db).await.map_err(Error::DB)?; 24 payload.id
25 )
26 .fetch_one(&state.db)
27 .await?;
25 28
26 debug!("got device {:?}", device); 29 debug!("got device {:?}", device);
27 30
28 Ok(Json(json!(device))) 31 Ok(Json(json!(device)))
29 } else {
30 Err(Error::Generic)
31 }
32} 32}
33 33
34#[derive(Deserialize)] 34#[derive(Deserialize)]
@@ -36,25 +36,31 @@ pub struct GetDevicePayload {
36 id: String, 36 id: String,
37} 37}
38 38
39pub async fn put(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PutDevicePayload>) -> Result<Json<Value>, Error> { 39pub async fn put(
40 info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); 40 State(state): State<Arc<crate::AppState>>,
41 let secret = headers.get("authorization"); 41 Json(payload): Json<PutDevicePayload>,
42 if auth(&state.config, secret).map_err(Error::Auth)? { 42) -> Result<Json<Value>, Error> {
43 sqlx::query!( 43 info!(
44 r#" 44 "add device {} ({}, {}, {})",
45 INSERT INTO devices (id, mac, broadcast_addr, ip) 45 payload.id, payload.mac, payload.broadcast_addr, payload.ip
46 VALUES ($1, $2, $3, $4); 46 );
47 "#, 47
48 payload.id, 48 let ip = IpNetwork::from_str(&payload.ip)?;
49 payload.mac, 49 let mac = MacAddress::from_str(&payload.mac)?;
50 payload.broadcast_addr, 50 sqlx::query!(
51 payload.ip 51 r#"
52 ).execute(&state.db).await.map_err(Error::DB)?; 52 INSERT INTO devices (id, mac, broadcast_addr, ip)
53 VALUES ($1, $2, $3, $4);
54 "#,
55 payload.id,
56 mac,
57 payload.broadcast_addr,
58 ip
59 )
60 .execute(&state.db)
61 .await?;
53 62
54 Ok(Json(json!(PutDeviceResponse { success: true }))) 63 Ok(Json(json!(PutDeviceResponse { success: true })))
55 } else {
56 Err(Error::Generic)
57 }
58} 64}
59 65
60#[derive(Deserialize)] 66#[derive(Deserialize)]
@@ -62,35 +68,40 @@ pub struct PutDevicePayload {
62 id: String, 68 id: String,
63 mac: String, 69 mac: String,
64 broadcast_addr: String, 70 broadcast_addr: String,
65 ip: String 71 ip: String,
66} 72}
67 73
68#[derive(Serialize)] 74#[derive(Serialize)]
69pub struct PutDeviceResponse { 75pub struct PutDeviceResponse {
70 success: bool 76 success: bool,
71} 77}
72 78
73pub async fn post(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PostDevicePayload>) -> Result<Json<Value>, Error> { 79pub async fn post(
74 info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); 80 State(state): State<Arc<crate::AppState>>,
75 let secret = headers.get("authorization"); 81 Json(payload): Json<PostDevicePayload>,
76 if auth(&state.config, secret).map_err(Error::Auth)? { 82) -> Result<Json<Value>, Error> {
77 let device = sqlx::query_as!( 83 info!(
78 Device, 84 "edit device {} ({}, {}, {})",
79 r#" 85 payload.id, payload.mac, payload.broadcast_addr, payload.ip
80 UPDATE devices 86 );
81 SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 87 let ip = IpNetwork::from_str(&payload.ip)?;
82 RETURNING id, mac, broadcast_addr, ip, times; 88 let mac = MacAddress::from_str(&payload.mac)?;
83 "#, 89 let device = sqlx::query_as!(
84 payload.mac, 90 Device,
85 payload.broadcast_addr, 91 r#"
86 payload.ip, 92 UPDATE devices
87 payload.id 93 SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4
88 ).fetch_one(&state.db).await.map_err(Error::DB)?; 94 RETURNING id, mac, broadcast_addr, ip, times;
95 "#,
96 mac,
97 payload.broadcast_addr,
98 ip,
99 payload.id
100 )
101 .fetch_one(&state.db)
102 .await?;
89 103
90 Ok(Json(json!(device))) 104 Ok(Json(json!(device)))
91 } else {
92 Err(Error::Generic)
93 }
94} 105}
95 106
96#[derive(Deserialize)] 107#[derive(Deserialize)]
diff --git a/src/routes/start.rs b/src/routes/start.rs
index ce95bf3..d4c0802 100644
--- a/src/routes/start.rs
+++ b/src/routes/start.rs
@@ -1,10 +1,8 @@
1use crate::auth::auth;
2use crate::db::Device; 1use crate::db::Device;
3use crate::error::Error; 2use crate::error::Error;
4use crate::services::ping::Value as PingValue; 3use crate::services::ping::Value as PingValue;
5use crate::wol::{create_buffer, send_packet}; 4use crate::wol::{create_buffer, send_packet};
6use axum::extract::State; 5use axum::extract::State;
7use axum::http::HeaderMap;
8use axum::Json; 6use axum::Json;
9use serde::{Deserialize, Serialize}; 7use serde::{Deserialize, Serialize};
10use serde_json::{json, Value}; 8use serde_json::{json, Value};
@@ -12,86 +10,82 @@ use std::sync::Arc;
12use tracing::{debug, info}; 10use tracing::{debug, info};
13use uuid::Uuid; 11use uuid::Uuid;
14 12
15#[axum_macros::debug_handler]
16pub async fn start( 13pub async fn start(
17 State(state): State<Arc<crate::AppState>>, 14 State(state): State<Arc<crate::AppState>>,
18 headers: HeaderMap,
19 Json(payload): Json<Payload>, 15 Json(payload): Json<Payload>,
20) -> Result<Json<Value>, Error> { 16) -> Result<Json<Value>, Error> {
21 info!("POST request"); 17 info!("POST request");
22 let secret = headers.get("authorization"); 18 let device = sqlx::query_as!(
23 let authorized = auth(&state.config, secret).map_err(Error::Auth)?; 19 Device,
24 if authorized { 20 r#"
25 let device = sqlx::query_as!( 21 SELECT id, mac, broadcast_addr, ip, times
26 Device, 22 FROM devices
27 r#" 23 WHERE id = $1;
28 SELECT id, mac, broadcast_addr, ip, times 24 "#,
29 FROM devices 25 payload.id
30 WHERE id = $1; 26 )
31 "#, 27 .fetch_one(&state.db)
32 payload.id 28 .await?;
33 )
34 .fetch_one(&state.db)
35 .await
36 .map_err(Error::DB)?;
37
38 info!("starting {}", device.id);
39
40 let bind_addr = "0.0.0.0:0";
41 29
42 let _ = send_packet( 30 info!("starting {}", device.id);
43 &bind_addr.parse().map_err(Error::IpParse)?,
44 &device.broadcast_addr.parse().map_err(Error::IpParse)?,
45 &create_buffer(&device.mac)?,
46 )?;
47 let dev_id = device.id.clone();
48 let uuid = if payload.ping.is_some_and(|ping| ping) {
49 let mut uuid: Option<String> = None;
50 for (key, value) in state.ping_map.clone() {
51 if value.ip == device.ip {
52 debug!("service already exists");
53 uuid = Some(key);
54 break;
55 }
56 }
57 let uuid_gen = match uuid {
58 Some(u) => u,
59 None => Uuid::new_v4().to_string(),
60 };
61 let uuid_genc = uuid_gen.clone();
62 31
63 tokio::spawn(async move { 32 let bind_addr = "0.0.0.0:0";
64 debug!("init ping service");
65 state.ping_map.insert(
66 uuid_gen.clone(),
67 PingValue {
68 ip: device.ip.clone(),
69 online: false,
70 },
71 );
72 33
73 crate::services::ping::spawn( 34 let _ = send_packet(
74 state.ping_send.clone(), 35 bind_addr,
75 &state.config, 36 &device.broadcast_addr,
76 device, 37 &create_buffer(&device.mac.to_string())?,
77 uuid_gen.clone(), 38 )?;
78 &state.ping_map, 39 let dev_id = device.id.clone();
79 &state.db, 40 let uuid = if payload.ping.is_some_and(|ping| ping) {
80 ) 41 Some(setup_ping(state, device))
81 .await;
82 });
83 Some(uuid_genc)
84 } else {
85 None
86 };
87 Ok(Json(json!(Response {
88 id: dev_id,
89 boot: true,
90 uuid
91 })))
92 } else { 42 } else {
93 Err(Error::Generic) 43 None
44 };
45 Ok(Json(json!(Response {
46 id: dev_id,
47 boot: true,
48 uuid
49 })))
50}
51
52fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String {
53 let mut uuid: Option<String> = None;
54 for (key, value) in state.ping_map.clone() {
55 if value.ip == device.ip {
56 debug!("service already exists");
57 uuid = Some(key);
58 break;
59 }
94 } 60 }
61 let uuid_gen = match uuid {
62 Some(u) => u,
63 None => Uuid::new_v4().to_string(),
64 };
65 let uuid_ret = uuid_gen.clone();
66
67 debug!("init ping service");
68 state.ping_map.insert(
69 uuid_gen.clone(),
70 PingValue {
71 ip: device.ip,
72 online: false,
73 },
74 );
75
76 tokio::spawn(async move {
77 crate::services::ping::spawn(
78 state.ping_send.clone(),
79 &state.config,
80 device,
81 uuid_gen,
82 &state.ping_map,
83 &state.db,
84 )
85 .await;
86 });
87
88 uuid_ret
95} 89}
96 90
97#[derive(Deserialize)] 91#[derive(Deserialize)]
diff --git a/src/routes/status.rs b/src/routes/status.rs
index 31ef996..0e25f7d 100644
--- a/src/routes/status.rs
+++ b/src/routes/status.rs
@@ -1,10 +1,79 @@
1use std::sync::Arc; 1use crate::services::ping::BroadcastCommand;
2use crate::AppState;
3use axum::extract::ws::{Message, WebSocket};
2use axum::extract::{State, WebSocketUpgrade}; 4use axum::extract::{State, WebSocketUpgrade};
3use axum::response::Response; 5use axum::response::Response;
4use crate::AppState; 6use sqlx::PgPool;
5use crate::services::ping::status_websocket; 7use std::sync::Arc;
8use tracing::{debug, trace};
6 9
7#[axum_macros::debug_handler]
8pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response { 10pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response {
9 ws.on_upgrade(move |socket| status_websocket(socket, state)) 11 ws.on_upgrade(move |socket| websocket(socket, state))
12}
13
14pub async fn websocket(mut socket: WebSocket, state: Arc<AppState>) {
15 trace!("wait for ws message (uuid)");
16 let msg = socket.recv().await;
17 let uuid = msg.unwrap().unwrap().into_text().unwrap();
18
19 trace!("Search for uuid: {}", uuid);
20
21 let eta = get_eta(&state.db).await;
22 let _ = socket
23 .send(Message::Text(format!("eta_{eta}_{uuid}")))
24 .await;
25
26 let device_exists = state.ping_map.contains_key(&uuid);
27 if device_exists {
28 let _ = socket
29 .send(receive_ping_broadcast(state.clone(), uuid).await)
30 .await;
31 } else {
32 debug!("didn't find any device");
33 let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await;
34 };
35
36 let _ = socket.close().await;
37}
38
39async fn receive_ping_broadcast(state: Arc<AppState>, uuid: String) -> Message {
40 let pm = state.ping_map.clone().into_read_only();
41 let device = pm.get(&uuid).expect("fatal error");
42 debug!("got device: {} (online: {})", device.ip, device.online);
43 if device.online {
44 debug!("already started");
45 Message::Text(BroadcastCommand::success(uuid).to_string())
46 } else {
47 loop {
48 trace!("wait for tx message");
49 let message = state
50 .ping_send
51 .subscribe()
52 .recv()
53 .await
54 .expect("fatal error");
55 trace!("got message {:?}", message);
56
57 if message.uuid != uuid {
58 continue;
59 }
60 trace!("message == uuid success");
61 return Message::Text(message.to_string());
62 }
63 }
64}
65
66async fn get_eta(db: &PgPool) -> i64 {
67 let query = sqlx::query!(r#"SELECT times FROM devices;"#)
68 .fetch_one(db)
69 .await
70 .unwrap();
71
72 let times = if let Some(times) = query.times {
73 times
74 } else {
75 vec![0]
76 };
77
78 times.iter().sum::<i64>() / i64::try_from(times.len()).unwrap()
10} 79}
diff --git a/src/services/mod.rs b/src/services.rs
index a766209..a766209 100644
--- a/src/services/mod.rs
+++ b/src/services.rs
diff --git a/src/services/ping.rs b/src/services/ping.rs
index 9b164c8..9191f86 100644
--- a/src/services/ping.rs
+++ b/src/services/ping.rs
@@ -1,59 +1,58 @@
1use std::str::FromStr; 1use crate::config::Config;
2use std::net::IpAddr; 2use crate::db::Device;
3use std::sync::Arc;
4
5use axum::extract::ws::WebSocket;
6use axum::extract::ws::Message;
7use dashmap::DashMap; 3use dashmap::DashMap;
4use ipnetwork::IpNetwork;
8use sqlx::PgPool; 5use sqlx::PgPool;
6use std::fmt::Display;
9use time::{Duration, Instant}; 7use time::{Duration, Instant};
10use tokio::sync::broadcast::Sender; 8use tokio::sync::broadcast::Sender;
11use tracing::{debug, error, trace}; 9use tracing::{debug, error, trace};
12use crate::AppState;
13use crate::config::Config;
14use crate::db::Device;
15 10
16pub type StatusMap = DashMap<String, Value>; 11pub type StatusMap = DashMap<String, Value>;
17 12
18#[derive(Debug, Clone)] 13#[derive(Debug, Clone)]
19pub struct Value { 14pub struct Value {
20 pub ip: String, 15 pub ip: IpNetwork,
21 pub online: bool 16 pub online: bool,
22} 17}
23 18
24pub async fn spawn(tx: Sender<BroadcastCommands>, config: &Config, device: Device, uuid: String, ping_map: &StatusMap, db: &PgPool) { 19pub async fn spawn(
20 tx: Sender<BroadcastCommand>,
21 config: &Config,
22 device: Device,
23 uuid: String,
24 ping_map: &StatusMap,
25 db: &PgPool,
26) {
25 let timer = Instant::now(); 27 let timer = Instant::now();
26 let payload = [0; 8]; 28 let payload = [0; 8];
27 29
28 let ping_ip = IpAddr::from_str(&device.ip).expect("bad ip"); 30 let mut msg: Option<BroadcastCommand> = None;
29
30 let mut msg: Option<BroadcastCommands> = None;
31 while msg.is_none() { 31 while msg.is_none() {
32 let ping = surge_ping::ping( 32 let ping = surge_ping::ping(device.ip.ip(), &payload).await;
33 ping_ip,
34 &payload
35 ).await;
36 33
37 if let Err(ping) = ping { 34 if let Err(ping) = ping {
38 let ping_timeout = matches!(ping, surge_ping::SurgeError::Timeout { .. }); 35 let ping_timeout = matches!(ping, surge_ping::SurgeError::Timeout { .. });
39 if !ping_timeout { 36 if !ping_timeout {
40 error!("{}", ping.to_string()); 37 error!("{}", ping.to_string());
41 msg = Some(BroadcastCommands::Error(uuid.clone())); 38 msg = Some(BroadcastCommand::error(uuid.clone()));
42 } 39 }
43 if timer.elapsed() >= Duration::minutes(config.pingtimeout) { 40 if timer.elapsed() >= Duration::minutes(config.pingtimeout) {
44 msg = Some(BroadcastCommands::Timeout(uuid.clone())); 41 msg = Some(BroadcastCommand::timeout(uuid.clone()));
45 } 42 }
46 } else { 43 } else {
47 let (_, duration) = ping.map_err(|err| error!("{}", err.to_string())).expect("fatal error"); 44 let (_, duration) = ping
45 .map_err(|err| error!("{}", err.to_string()))
46 .expect("fatal error");
48 debug!("ping took {:?}", duration); 47 debug!("ping took {:?}", duration);
49 msg = Some(BroadcastCommands::Success(uuid.clone())); 48 msg = Some(BroadcastCommand::success(uuid.clone()));
50 }; 49 };
51 } 50 }
52 51
53 let msg = msg.expect("fatal error"); 52 let msg = msg.expect("fatal error");
54 53
55 let _ = tx.send(msg.clone()); 54 let _ = tx.send(msg.clone());
56 if let BroadcastCommands::Success(..) = msg { 55 if let BroadcastCommands::Success = msg.command {
57 sqlx::query!( 56 sqlx::query!(
58 r#" 57 r#"
59 UPDATE devices 58 UPDATE devices
@@ -62,8 +61,17 @@ pub async fn spawn(tx: Sender<BroadcastCommands>, config: &Config, device: Devic
62 "#, 61 "#,
63 timer.elapsed().whole_seconds(), 62 timer.elapsed().whole_seconds(),
64 device.id 63 device.id
65 ).execute(db).await.unwrap(); 64 )
66 ping_map.insert(uuid.clone(), Value { ip: device.ip.clone(), online: true }); 65 .execute(db)
66 .await
67 .unwrap();
68 ping_map.insert(
69 uuid.clone(),
70 Value {
71 ip: device.ip,
72 online: true,
73 },
74 );
67 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; 75 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
68 } 76 }
69 trace!("remove {} from ping_map", uuid); 77 trace!("remove {} from ping_map", uuid);
@@ -72,74 +80,48 @@ pub async fn spawn(tx: Sender<BroadcastCommands>, config: &Config, device: Devic
72 80
73#[derive(Clone, Debug, PartialEq)] 81#[derive(Clone, Debug, PartialEq)]
74pub enum BroadcastCommands { 82pub enum BroadcastCommands {
75 Success(String), 83 Success,
76 Timeout(String), 84 Timeout,
77 Error(String), 85 Error,
78} 86}
79 87
80pub async fn status_websocket(mut socket: WebSocket, state: Arc<AppState>) { 88#[derive(Clone, Debug, PartialEq)]
81 trace!("wait for ws message (uuid)"); 89pub struct BroadcastCommand {
82 let msg = socket.recv().await; 90 pub uuid: String,
83 let uuid = msg.unwrap().unwrap().into_text().unwrap(); 91 pub command: BroadcastCommands,
84 92}
85 trace!("Search for uuid: {}", uuid);
86
87 let eta = get_eta(&state.db).await;
88 let _ = socket.send(Message::Text(format!("eta_{eta}_{uuid}"))).await;
89 93
90 let device_exists = state.ping_map.contains_key(&uuid); 94impl Display for BroadcastCommand {
91 if device_exists { 95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 let _ = socket.send(process_device(state.clone(), uuid).await).await; 96 let prefix = match self.command {
93 } else { 97 BroadcastCommands::Success => "start",
94 debug!("didn't find any device"); 98 BroadcastCommands::Timeout => "timeout",
95 let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await; 99 BroadcastCommands::Error => "error",
96 }; 100 };
97 101
98 let _ = socket.close().await; 102 f.write_str(format!("{prefix}_{}", self.uuid).as_str())
103 }
99} 104}
100 105
101async fn get_eta(db: &PgPool) -> i64 { 106impl BroadcastCommand {
102 let query = sqlx::query!( 107 pub fn success(uuid: String) -> Self {
103 r#"SELECT times FROM devices;"# 108 Self {
104 ).fetch_one(db).await.unwrap(); 109 uuid,
105 110 command: BroadcastCommands::Success,
106 let times = match query.times { 111 }
107 None => { vec![0] }, 112 }
108 Some(t) => t,
109 };
110 times.iter().sum::<i64>() / i64::try_from(times.len()).unwrap()
111 113
112} 114 pub fn timeout(uuid: String) -> Self {
115 Self {
116 uuid,
117 command: BroadcastCommands::Timeout,
118 }
119 }
113 120
114async fn process_device(state: Arc<AppState>, uuid: String) -> Message { 121 pub fn error(uuid: String) -> Self {
115 let pm = state.ping_map.clone().into_read_only(); 122 Self {
116 let device = pm.get(&uuid).expect("fatal error"); 123 uuid,
117 debug!("got device: {} (online: {})", device.ip, device.online); 124 command: BroadcastCommands::Error,
118 if device.online {
119 debug!("already started");
120 Message::Text(format!("start_{uuid}"))
121 } else {
122 loop {
123 trace!("wait for tx message");
124 let message = state.ping_send.subscribe().recv().await.expect("fatal error");
125 trace!("got message {:?}", message);
126 return match message {
127 BroadcastCommands::Success(msg_uuid) => {
128 if msg_uuid != uuid { continue; }
129 trace!("message == uuid success");
130 Message::Text(format!("start_{uuid}"))
131 },
132 BroadcastCommands::Timeout(msg_uuid) => {
133 if msg_uuid != uuid { continue; }
134 trace!("message == uuid timeout");
135 Message::Text(format!("timeout_{uuid}"))
136 },
137 BroadcastCommands::Error(msg_uuid) => {
138 if msg_uuid != uuid { continue; }
139 trace!("message == uuid error");
140 Message::Text(format!("error_{uuid}"))
141 }
142 }
143 } 125 }
144 } 126 }
145} 127}
diff --git a/src/wol.rs b/src/wol.rs
index 83c0ee6..31cf350 100644
--- a/src/wol.rs
+++ b/src/wol.rs
@@ -1,4 +1,4 @@
1use std::net::{SocketAddr, UdpSocket}; 1use std::net::{ToSocketAddrs, UdpSocket};
2 2
3use crate::error::Error; 3use crate::error::Error;
4 4
@@ -11,8 +11,8 @@ pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, Error> {
11 let mut mac = Vec::new(); 11 let mut mac = Vec::new();
12 let sp = mac_addr.split(':'); 12 let sp = mac_addr.split(':');
13 for f in sp { 13 for f in sp {
14 mac.push(u8::from_str_radix(f, 16).map_err(Error::BufferParse)?); 14 mac.push(u8::from_str_radix(f, 16)?);
15 }; 15 }
16 let mut buf = vec![255; 6]; 16 let mut buf = vec![255; 6];
17 for _ in 0..16 { 17 for _ in 0..16 {
18 for i in &mac { 18 for i in &mac {
@@ -23,8 +23,12 @@ pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, Error> {
23} 23}
24 24
25/// Sends a buffer on UDP broadcast 25/// Sends a buffer on UDP broadcast
26pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: &[u8]) -> Result<usize, Error> { 26pub fn send_packet<A: ToSocketAddrs>(
27 let socket = UdpSocket::bind(bind_addr).map_err(Error::Broadcast)?; 27 bind_addr: A,
28 socket.set_broadcast(true).map_err(Error::Broadcast)?; 28 broadcast_addr: A,
29 socket.send_to(buffer, broadcast_addr).map_err(Error::Broadcast) 29 buffer: &[u8],
30) -> Result<usize, Error> {
31 let socket = UdpSocket::bind(bind_addr)?;
32 socket.set_broadcast(true)?;
33 Ok(socket.send_to(buffer, broadcast_addr)?)
30} 34}