summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorFxQnLr <[email protected]>2024-02-12 16:00:45 +0100
committerGitHub <[email protected]>2024-02-12 16:00:45 +0100
commitc663810817183c8f92a4279236ca84d271365088 (patch)
tree0c844cc883e5e474a9cdad30004108852f13f903 /src
parentda6367885d31698464e1bec122e3e673974427c6 (diff)
parent9139d76cb1cf462820b2ddfa80d9a8d55bb30996 (diff)
downloadwebol-c663810817183c8f92a4279236ca84d271365088.tar
webol-c663810817183c8f92a4279236ca84d271365088.tar.gz
webol-c663810817183c8f92a4279236ca84d271365088.zip
Merge pull request #14 from FxQnLr/axum7
Axum7 & config changes
Diffstat (limited to 'src')
-rw-r--r--src/auth.rs24
-rw-r--r--src/config.rs32
-rw-r--r--src/db.rs16
-rw-r--r--src/error.rs8
-rw-r--r--src/main.rs41
-rw-r--r--src/routes/device.rs28
-rw-r--r--src/routes/start.rs76
-rw-r--r--src/routes/status.rs2
-rw-r--r--src/services/ping.rs84
-rw-r--r--src/wol.rs14
10 files changed, 172 insertions, 153 deletions
diff --git a/src/auth.rs b/src/auth.rs
index e4b1c2f..feca652 100644
--- a/src/auth.rs
+++ b/src/auth.rs
@@ -1,18 +1,15 @@
1use axum::headers::HeaderValue; 1use axum::http::{StatusCode, HeaderValue};
2use axum::http::StatusCode;
3use axum::http::header::ToStrError; 2use axum::http::header::ToStrError;
4use tracing::{debug, error, trace}; 3use tracing::{debug, error, trace};
5use crate::auth::AuthError::{MissingSecret, WrongSecret}; 4use crate::auth::Error::{MissingSecret, WrongSecret};
6use crate::config::SETTINGS; 5use crate::config::Config;
7 6
8pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> { 7pub fn auth(config: &Config, secret: Option<&HeaderValue>) -> Result<bool, Error> {
9 debug!("auth request with secret {:?}", secret); 8 debug!("auth request with secret {:?}", secret);
10 if let Some(value) = secret { 9 if let Some(value) = secret {
11 trace!("value exists"); 10 trace!("value exists");
12 let key = SETTINGS 11 let key = &config.apikey;
13 .get_string("apikey") 12 if value.to_str().map_err(Error::HeaderToStr)? == key.as_str() {
14 .map_err(AuthError::Config)?;
15 if value.to_str().map_err(AuthError::HeaderToStr)? == key.as_str() {
16 debug!("successful auth"); 13 debug!("successful auth");
17 Ok(true) 14 Ok(true)
18 } else { 15 } else {
@@ -26,22 +23,17 @@ pub fn auth(secret: Option<&HeaderValue>) -> Result<bool, AuthError> {
26} 23}
27 24
28#[derive(Debug)] 25#[derive(Debug)]
29pub enum AuthError { 26pub enum Error {
30 WrongSecret, 27 WrongSecret,
31 MissingSecret, 28 MissingSecret,
32 Config(config::ConfigError),
33 HeaderToStr(ToStrError) 29 HeaderToStr(ToStrError)
34} 30}
35 31
36impl AuthError { 32impl Error {
37 pub fn get(self) -> (StatusCode, &'static str) { 33 pub fn get(self) -> (StatusCode, &'static str) {
38 match self { 34 match self {
39 Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"), 35 Self::WrongSecret => (StatusCode::UNAUTHORIZED, "Wrong credentials"),
40 Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"), 36 Self::MissingSecret => (StatusCode::BAD_REQUEST, "Missing credentials"),
41 Self::Config(err) => {
42 error!("server error: {}", err.to_string());
43 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
44 },
45 Self::HeaderToStr(err) => { 37 Self::HeaderToStr(err) => {
46 error!("server error: {}", err.to_string()); 38 error!("server error: {}", err.to_string());
47 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") 39 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
diff --git a/src/config.rs b/src/config.rs
index 4c79810..4319ffc 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -1,11 +1,25 @@
1use config::Config; 1use config::File;
2use once_cell::sync::Lazy; 2use serde::Deserialize;
3 3
4pub static SETTINGS: Lazy<Config> = Lazy::new(setup); 4#[derive(Debug, Clone, Deserialize)]
5pub struct Config {
6 pub database_url: String,
7 pub apikey: String,
8 pub serveraddr: String,
9 pub pingtimeout: i64,
10}
11
12impl Config {
13 pub fn load() -> Result<Self, config::ConfigError> {
14 let config = config::Config::builder()
15 .set_default("serveraddr", "0.0.0.0:7229")?
16 .set_default("pingtimeout", 10)?
17 .add_source(File::with_name("config.toml").required(false))
18 .add_source(File::with_name("config.dev.toml").required(false))
19 .add_source(config::Environment::with_prefix("WEBOL").prefix_separator("_"))
20 .build()?;
21
22 config.try_deserialize()
23 }
24}
5 25
6fn setup() -> Config {
7 Config::builder()
8 .add_source(config::Environment::with_prefix("WEBOL").separator("_"))
9 .build()
10 .unwrap()
11} \ No newline at end of file
diff --git a/src/db.rs b/src/db.rs
index 8a6b16e..489a000 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -1,13 +1,7 @@
1#[cfg(debug_assertions)]
2use std::env;
3
4use serde::Serialize; 1use serde::Serialize;
5use sqlx::{PgPool, postgres::PgPoolOptions}; 2use sqlx::{PgPool, postgres::PgPoolOptions};
6use tracing::{debug, info}; 3use tracing::{debug, info};
7 4
8#[cfg(not(debug_assertions))]
9use crate::config::SETTINGS;
10
11#[derive(Serialize, Debug)] 5#[derive(Serialize, Debug)]
12pub struct Device { 6pub struct Device {
13 pub id: String, 7 pub id: String,
@@ -17,18 +11,12 @@ pub struct Device {
17 pub times: Option<Vec<i64>> 11 pub times: Option<Vec<i64>>
18} 12}
19 13
20pub async fn init_db_pool() -> PgPool { 14pub async fn init_db_pool(db_url: &str) -> PgPool {
21 #[cfg(not(debug_assertions))]
22 let db_url = SETTINGS.get_string("database.url").unwrap();
23
24 #[cfg(debug_assertions)]
25 let db_url = env::var("DATABASE_URL").unwrap();
26
27 debug!("attempt to connect dbPool to '{}'", db_url); 15 debug!("attempt to connect dbPool to '{}'", db_url);
28 16
29 let pool = PgPoolOptions::new() 17 let pool = PgPoolOptions::new()
30 .max_connections(5) 18 .max_connections(5)
31 .connect(&db_url) 19 .connect(db_url)
32 .await 20 .await
33 .unwrap(); 21 .unwrap();
34 22
diff --git a/src/error.rs b/src/error.rs
index 5b82534..56d6c52 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -4,10 +4,10 @@ use axum::Json;
4use axum::response::{IntoResponse, Response}; 4use axum::response::{IntoResponse, Response};
5use serde_json::json; 5use serde_json::json;
6use tracing::error; 6use tracing::error;
7use crate::auth::AuthError; 7use crate::auth::Error as AuthError;
8 8
9#[derive(Debug)] 9#[derive(Debug)]
10pub enum WebolError { 10pub enum Error {
11 Generic, 11 Generic,
12 Auth(AuthError), 12 Auth(AuthError),
13 DB(sqlx::Error), 13 DB(sqlx::Error),
@@ -16,7 +16,7 @@ pub enum WebolError {
16 Broadcast(io::Error), 16 Broadcast(io::Error),
17} 17}
18 18
19impl IntoResponse for WebolError { 19impl IntoResponse for Error {
20 fn into_response(self) -> Response { 20 fn into_response(self) -> Response {
21 let (status, error_message) = match self { 21 let (status, error_message) = match self {
22 Self::Auth(err) => { 22 Self::Auth(err) => {
@@ -45,4 +45,4 @@ impl IntoResponse for WebolError {
45 })); 45 }));
46 (status, body).into_response() 46 (status, body).into_response()
47 } 47 }
48} \ No newline at end of file 48}
diff --git a/src/main.rs b/src/main.rs
index e96b736..4ef129b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,12 +8,12 @@ use time::util::local_offset;
8use tokio::sync::broadcast::{channel, Sender}; 8use tokio::sync::broadcast::{channel, Sender};
9use tracing::{info, level_filters::LevelFilter}; 9use tracing::{info, level_filters::LevelFilter};
10use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*}; 10use tracing_subscriber::{EnvFilter, fmt::{self, time::LocalTime}, prelude::*};
11use crate::config::SETTINGS; 11use crate::config::Config;
12use crate::db::init_db_pool; 12use crate::db::init_db_pool;
13use crate::routes::device::{get_device, post_device, put_device}; 13use crate::routes::device;
14use crate::routes::start::start; 14use crate::routes::start::start;
15use crate::routes::status::status; 15use crate::routes::status::status;
16use crate::services::ping::{BroadcastCommands, PingMap}; 16use crate::services::ping::{BroadcastCommands, StatusMap};
17 17
18mod auth; 18mod auth;
19mod config; 19mod config;
@@ -24,7 +24,10 @@ mod error;
24mod services; 24mod services;
25 25
26#[tokio::main] 26#[tokio::main]
27async fn main() { 27async fn main() -> color_eyre::eyre::Result<()> {
28
29 color_eyre::install()?;
30
28 unsafe { local_offset::set_soundness(local_offset::Soundness::Unsound); } 31 unsafe { local_offset::set_soundness(local_offset::Soundness::Unsound); }
29 let time_format = 32 let time_format =
30 time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); 33 time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]");
@@ -43,35 +46,39 @@ async fn main() {
43 46
44 let version = env!("CARGO_PKG_VERSION"); 47 let version = env!("CARGO_PKG_VERSION");
45 48
49 let config = Config::load()?;
50
46 info!("start webol v{}", version); 51 info!("start webol v{}", version);
47 52
48 let db = init_db_pool().await; 53 let db = init_db_pool(&config.database_url).await;
49 sqlx::migrate!().run(&db).await.unwrap(); 54 sqlx::migrate!().run(&db).await.unwrap();
50 55
51 let (tx, _) = channel(32); 56 let (tx, _) = channel(32);
52 57
53 let ping_map: PingMap = DashMap::new(); 58 let ping_map: StatusMap = DashMap::new();
54 59
55 let shared_state = Arc::new(AppState { db, ping_send: tx, ping_map }); 60 let shared_state = Arc::new(AppState { db, config: config.clone(), ping_send: tx, ping_map });
56 61
57 let app = Router::new() 62 let app = Router::new()
58 .route("/start", post(start)) 63 .route("/start", post(start))
59 .route("/device", get(get_device)) 64 .route("/device", get(device::get))
60 .route("/device", put(put_device)) 65 .route("/device", put(device::put))
61 .route("/device", post(post_device)) 66 .route("/device", post(device::post))
62 .route("/status", get(status)) 67 .route("/status", get(status))
63 .with_state(shared_state); 68 .with_state(shared_state);
64 69
65 let addr = SETTINGS.get_string("serveraddr").unwrap_or("0.0.0.0:7229".to_string()); 70 let addr = config.serveraddr;
66 info!("start server on {}", addr); 71 info!("start server on {}", addr);
67 axum::Server::bind(&addr.parse().unwrap()) 72 let listener = tokio::net::TcpListener::bind(addr)
68 .serve(app.into_make_service()) 73 .await?;
69 .await 74 axum::serve(listener, app).await?;
70 .unwrap(); 75
76 Ok(())
71} 77}
72 78
73pub struct AppState { 79pub struct AppState {
74 db: PgPool, 80 db: PgPool,
81 config: Config,
75 ping_send: Sender<BroadcastCommands>, 82 ping_send: Sender<BroadcastCommands>,
76 ping_map: PingMap, 83 ping_map: StatusMap,
77} \ No newline at end of file 84}
diff --git a/src/routes/device.rs b/src/routes/device.rs
index 678d117..c85df1b 100644
--- a/src/routes/device.rs
+++ b/src/routes/device.rs
@@ -1,18 +1,18 @@
1use std::sync::Arc; 1use std::sync::Arc;
2use axum::extract::State; 2use axum::extract::State;
3use axum::headers::HeaderMap;
4use axum::Json; 3use axum::Json;
4use axum::http::HeaderMap;
5use serde::{Deserialize, Serialize}; 5use serde::{Deserialize, Serialize};
6use serde_json::{json, Value}; 6use serde_json::{json, Value};
7use tracing::{debug, info}; 7use tracing::{debug, info};
8use crate::auth::auth; 8use crate::auth::auth;
9use crate::db::Device; 9use crate::db::Device;
10use crate::error::WebolError; 10use crate::error::Error;
11 11
12pub async fn get_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<GetDevicePayload>) -> Result<Json<Value>, WebolError> { 12pub async fn get(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<GetDevicePayload>) -> Result<Json<Value>, Error> {
13 info!("add device {}", payload.id); 13 info!("add device {}", payload.id);
14 let secret = headers.get("authorization"); 14 let secret = headers.get("authorization");
15 if auth(secret).map_err(WebolError::Auth)? { 15 if auth(&state.config, secret).map_err(Error::Auth)? {
16 let device = sqlx::query_as!( 16 let device = sqlx::query_as!(
17 Device, 17 Device,
18 r#" 18 r#"
@@ -21,13 +21,13 @@ pub async fn get_device(State(state): State<Arc<crate::AppState>>, headers: Head
21 WHERE id = $1; 21 WHERE id = $1;
22 "#, 22 "#,
23 payload.id 23 payload.id
24 ).fetch_one(&state.db).await.map_err(WebolError::DB)?; 24 ).fetch_one(&state.db).await.map_err(Error::DB)?;
25 25
26 debug!("got device {:?}", device); 26 debug!("got device {:?}", device);
27 27
28 Ok(Json(json!(device))) 28 Ok(Json(json!(device)))
29 } else { 29 } else {
30 Err(WebolError::Generic) 30 Err(Error::Generic)
31 } 31 }
32} 32}
33 33
@@ -36,10 +36,10 @@ pub struct GetDevicePayload {
36 id: String, 36 id: String,
37} 37}
38 38
39pub async fn put_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PutDevicePayload>) -> Result<Json<Value>, WebolError> { 39pub async fn put(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PutDevicePayload>) -> Result<Json<Value>, Error> {
40 info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); 40 info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip);
41 let secret = headers.get("authorization"); 41 let secret = headers.get("authorization");
42 if auth(secret).map_err(WebolError::Auth)? { 42 if auth(&state.config, secret).map_err(Error::Auth)? {
43 sqlx::query!( 43 sqlx::query!(
44 r#" 44 r#"
45 INSERT INTO devices (id, mac, broadcast_addr, ip) 45 INSERT INTO devices (id, mac, broadcast_addr, ip)
@@ -49,11 +49,11 @@ pub async fn put_device(State(state): State<Arc<crate::AppState>>, headers: Head
49 payload.mac, 49 payload.mac,
50 payload.broadcast_addr, 50 payload.broadcast_addr,
51 payload.ip 51 payload.ip
52 ).execute(&state.db).await.map_err(WebolError::DB)?; 52 ).execute(&state.db).await.map_err(Error::DB)?;
53 53
54 Ok(Json(json!(PutDeviceResponse { success: true }))) 54 Ok(Json(json!(PutDeviceResponse { success: true })))
55 } else { 55 } else {
56 Err(WebolError::Generic) 56 Err(Error::Generic)
57 } 57 }
58} 58}
59 59
@@ -70,10 +70,10 @@ pub struct PutDeviceResponse {
70 success: bool 70 success: bool
71} 71}
72 72
73pub async fn post_device(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PostDevicePayload>) -> Result<Json<Value>, WebolError> { 73pub async fn post(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PostDevicePayload>) -> Result<Json<Value>, Error> {
74 info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); 74 info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip);
75 let secret = headers.get("authorization"); 75 let secret = headers.get("authorization");
76 if auth(secret).map_err(WebolError::Auth)? { 76 if auth(&state.config, secret).map_err(Error::Auth)? {
77 let device = sqlx::query_as!( 77 let device = sqlx::query_as!(
78 Device, 78 Device,
79 r#" 79 r#"
@@ -85,11 +85,11 @@ pub async fn post_device(State(state): State<Arc<crate::AppState>>, headers: Hea
85 payload.broadcast_addr, 85 payload.broadcast_addr,
86 payload.ip, 86 payload.ip,
87 payload.id 87 payload.id
88 ).fetch_one(&state.db).await.map_err(WebolError::DB)?; 88 ).fetch_one(&state.db).await.map_err(Error::DB)?;
89 89
90 Ok(Json(json!(device))) 90 Ok(Json(json!(device)))
91 } else { 91 } else {
92 Err(WebolError::Generic) 92 Err(Error::Generic)
93 } 93 }
94} 94}
95 95
diff --git a/src/routes/start.rs b/src/routes/start.rs
index 1555db3..ce95bf3 100644
--- a/src/routes/start.rs
+++ b/src/routes/start.rs
@@ -1,23 +1,26 @@
1use axum::headers::HeaderMap; 1use crate::auth::auth;
2use crate::db::Device;
3use crate::error::Error;
4use crate::services::ping::Value as PingValue;
5use crate::wol::{create_buffer, send_packet};
6use axum::extract::State;
7use axum::http::HeaderMap;
2use axum::Json; 8use axum::Json;
3use serde::{Deserialize, Serialize}; 9use serde::{Deserialize, Serialize};
4use std::sync::Arc;
5use axum::extract::State;
6use serde_json::{json, Value}; 10use serde_json::{json, Value};
11use std::sync::Arc;
7use tracing::{debug, info}; 12use tracing::{debug, info};
8use uuid::Uuid; 13use uuid::Uuid;
9use crate::auth::auth;
10use crate::config::SETTINGS;
11use crate::wol::{create_buffer, send_packet};
12use crate::db::Device;
13use crate::error::WebolError;
14use crate::services::ping::PingValue;
15 14
16#[axum_macros::debug_handler] 15#[axum_macros::debug_handler]
17pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<StartPayload>) -> Result<Json<Value>, WebolError> { 16pub async fn start(
17 State(state): State<Arc<crate::AppState>>,
18 headers: HeaderMap,
19 Json(payload): Json<Payload>,
20) -> Result<Json<Value>, Error> {
18 info!("POST request"); 21 info!("POST request");
19 let secret = headers.get("authorization"); 22 let secret = headers.get("authorization");
20 let authorized = auth(secret).map_err(WebolError::Auth)?; 23 let authorized = auth(&state.config, secret).map_err(Error::Auth)?;
21 if authorized { 24 if authorized {
22 let device = sqlx::query_as!( 25 let device = sqlx::query_as!(
23 Device, 26 Device,
@@ -27,18 +30,19 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap
27 WHERE id = $1; 30 WHERE id = $1;
28 "#, 31 "#,
29 payload.id 32 payload.id
30 ).fetch_one(&state.db).await.map_err(WebolError::DB)?; 33 )
34 .fetch_one(&state.db)
35 .await
36 .map_err(Error::DB)?;
31 37
32 info!("starting {}", device.id); 38 info!("starting {}", device.id);
33 39
34 let bind_addr = SETTINGS 40 let bind_addr = "0.0.0.0:0";
35 .get_string("bindaddr")
36 .unwrap_or("0.0.0.0:1111".to_string());
37 41
38 let _ = send_packet( 42 let _ = send_packet(
39 &bind_addr.parse().map_err(WebolError::IpParse)?, 43 &bind_addr.parse().map_err(Error::IpParse)?,
40 &device.broadcast_addr.parse().map_err(WebolError::IpParse)?, 44 &device.broadcast_addr.parse().map_err(Error::IpParse)?,
41 create_buffer(&device.mac)? 45 &create_buffer(&device.mac)?,
42 )?; 46 )?;
43 let dev_id = device.id.clone(); 47 let dev_id = device.id.clone();
44 let uuid = if payload.ping.is_some_and(|ping| ping) { 48 let uuid = if payload.ping.is_some_and(|ping| ping) {
@@ -49,7 +53,7 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap
49 uuid = Some(key); 53 uuid = Some(key);
50 break; 54 break;
51 } 55 }
52 }; 56 }
53 let uuid_gen = match uuid { 57 let uuid_gen = match uuid {
54 Some(u) => u, 58 Some(u) => u,
55 None => Uuid::new_v4().to_string(), 59 None => Uuid::new_v4().to_string(),
@@ -58,26 +62,46 @@ pub async fn start(State(state): State<Arc<crate::AppState>>, headers: HeaderMap
58 62
59 tokio::spawn(async move { 63 tokio::spawn(async move {
60 debug!("init ping service"); 64 debug!("init ping service");
61 state.ping_map.insert(uuid_gen.clone(), PingValue { ip: device.ip.clone(), online: false }); 65 state.ping_map.insert(
66 uuid_gen.clone(),
67 PingValue {
68 ip: device.ip.clone(),
69 online: false,
70 },
71 );
62 72
63 crate::services::ping::spawn(state.ping_send.clone(), device, uuid_gen.clone(), &state.ping_map, &state.db).await 73 crate::services::ping::spawn(
74 state.ping_send.clone(),
75 &state.config,
76 device,
77 uuid_gen.clone(),
78 &state.ping_map,
79 &state.db,
80 )
81 .await;
64 }); 82 });
65 Some(uuid_genc) 83 Some(uuid_genc)
66 } else { None }; 84 } else {
67 Ok(Json(json!(StartResponse { id: dev_id, boot: true, uuid }))) 85 None
86 };
87 Ok(Json(json!(Response {
88 id: dev_id,
89 boot: true,
90 uuid
91 })))
68 } else { 92 } else {
69 Err(WebolError::Generic) 93 Err(Error::Generic)
70 } 94 }
71} 95}
72 96
73#[derive(Deserialize)] 97#[derive(Deserialize)]
74pub struct StartPayload { 98pub struct Payload {
75 id: String, 99 id: String,
76 ping: Option<bool>, 100 ping: Option<bool>,
77} 101}
78 102
79#[derive(Serialize)] 103#[derive(Serialize)]
80struct StartResponse { 104struct Response {
81 id: String, 105 id: String,
82 boot: bool, 106 boot: bool,
83 uuid: Option<String>, 107 uuid: Option<String>,
diff --git a/src/routes/status.rs b/src/routes/status.rs
index 45f3e51..31ef996 100644
--- a/src/routes/status.rs
+++ b/src/routes/status.rs
@@ -7,4 +7,4 @@ use crate::services::ping::status_websocket;
7#[axum_macros::debug_handler] 7#[axum_macros::debug_handler]
8pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response { 8pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response {
9 ws.on_upgrade(move |socket| status_websocket(socket, state)) 9 ws.on_upgrade(move |socket| status_websocket(socket, state))
10} \ No newline at end of file 10}
diff --git a/src/services/ping.rs b/src/services/ping.rs
index c3bdced..9b164c8 100644
--- a/src/services/ping.rs
+++ b/src/services/ping.rs
@@ -2,26 +2,26 @@ use std::str::FromStr;
2use std::net::IpAddr; 2use std::net::IpAddr;
3use std::sync::Arc; 3use std::sync::Arc;
4 4
5use axum::extract::{ws::WebSocket}; 5use axum::extract::ws::WebSocket;
6use axum::extract::ws::Message; 6use axum::extract::ws::Message;
7use dashmap::DashMap; 7use dashmap::DashMap;
8use sqlx::PgPool; 8use sqlx::PgPool;
9use time::{Duration, Instant}; 9use time::{Duration, Instant};
10use tokio::sync::broadcast::{Sender}; 10use tokio::sync::broadcast::Sender;
11use tracing::{debug, error, trace}; 11use tracing::{debug, error, trace};
12use crate::AppState; 12use crate::AppState;
13use crate::config::SETTINGS; 13use crate::config::Config;
14use crate::db::Device; 14use crate::db::Device;
15 15
16pub type PingMap = DashMap<String, PingValue>; 16pub type StatusMap = DashMap<String, Value>;
17 17
18#[derive(Debug, Clone)] 18#[derive(Debug, Clone)]
19pub struct PingValue { 19pub struct Value {
20 pub ip: String, 20 pub ip: String,
21 pub online: bool 21 pub online: bool
22} 22}
23 23
24pub async fn spawn(tx: Sender<BroadcastCommands>, device: Device, uuid: String, ping_map: &PingMap, db: &PgPool) { 24pub async fn spawn(tx: Sender<BroadcastCommands>, config: &Config, device: Device, uuid: String, ping_map: &StatusMap, db: &PgPool) {
25 let timer = Instant::now(); 25 let timer = Instant::now();
26 let payload = [0; 8]; 26 let payload = [0; 8];
27 27
@@ -40,7 +40,7 @@ pub async fn spawn(tx: Sender<BroadcastCommands>, device: Device, uuid: String,
40 error!("{}", ping.to_string()); 40 error!("{}", ping.to_string());
41 msg = Some(BroadcastCommands::Error(uuid.clone())); 41 msg = Some(BroadcastCommands::Error(uuid.clone()));
42 } 42 }
43 if timer.elapsed() >= Duration::minutes(SETTINGS.get_int("pingtimeout").unwrap_or(10)) { 43 if timer.elapsed() >= Duration::minutes(config.pingtimeout) {
44 msg = Some(BroadcastCommands::Timeout(uuid.clone())); 44 msg = Some(BroadcastCommands::Timeout(uuid.clone()));
45 } 45 }
46 } else { 46 } else {
@@ -63,7 +63,7 @@ pub async fn spawn(tx: Sender<BroadcastCommands>, device: Device, uuid: String,
63 timer.elapsed().whole_seconds(), 63 timer.elapsed().whole_seconds(),
64 device.id 64 device.id
65 ).execute(db).await.unwrap(); 65 ).execute(db).await.unwrap();
66 ping_map.insert(uuid.clone(), PingValue { ip: device.ip.clone(), online: true }); 66 ping_map.insert(uuid.clone(), Value { ip: device.ip.clone(), online: true });
67 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; 67 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
68 } 68 }
69 trace!("remove {} from ping_map", uuid); 69 trace!("remove {} from ping_map", uuid);
@@ -85,17 +85,14 @@ pub async fn status_websocket(mut socket: WebSocket, state: Arc<AppState>) {
85 trace!("Search for uuid: {}", uuid); 85 trace!("Search for uuid: {}", uuid);
86 86
87 let eta = get_eta(&state.db).await; 87 let eta = get_eta(&state.db).await;
88 let _ = socket.send(Message::Text(format!("eta_{}_{}", eta, uuid))).await; 88 let _ = socket.send(Message::Text(format!("eta_{eta}_{uuid}"))).await;
89 89
90 let device_exists = state.ping_map.contains_key(&uuid); 90 let device_exists = state.ping_map.contains_key(&uuid);
91 match device_exists { 91 if device_exists {
92 true => { 92 let _ = socket.send(process_device(state.clone(), uuid).await).await;
93 let _ = socket.send(process_device(state.clone(), uuid).await).await; 93 } else {
94 }, 94 debug!("didn't find any device");
95 false => { 95 let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await;
96 debug!("didn't find any device");
97 let _ = socket.send(Message::Text(format!("notfound_{}", uuid))).await;
98 },
99 }; 96 };
100 97
101 let _ = socket.close().await; 98 let _ = socket.close().await;
@@ -110,7 +107,7 @@ async fn get_eta(db: &PgPool) -> i64 {
110 None => { vec![0] }, 107 None => { vec![0] },
111 Some(t) => t, 108 Some(t) => t,
112 }; 109 };
113 times.iter().sum::<i64>() / times.len() as i64 110 times.iter().sum::<i64>() / i64::try_from(times.len()).unwrap()
114 111
115} 112}
116 113
@@ -118,34 +115,31 @@ async fn process_device(state: Arc<AppState>, uuid: String) -> Message {
118 let pm = state.ping_map.clone().into_read_only(); 115 let pm = state.ping_map.clone().into_read_only();
119 let device = pm.get(&uuid).expect("fatal error"); 116 let device = pm.get(&uuid).expect("fatal error");
120 debug!("got device: {} (online: {})", device.ip, device.online); 117 debug!("got device: {} (online: {})", device.ip, device.online);
121 match device.online { 118 if device.online {
122 true => { 119 debug!("already started");
123 debug!("already started"); 120 Message::Text(format!("start_{uuid}"))
124 Message::Text(format!("start_{}", uuid)) 121 } else {
125 }, 122 loop {
126 false => { 123 trace!("wait for tx message");
127 loop{ 124 let message = state.ping_send.subscribe().recv().await.expect("fatal error");
128 trace!("wait for tx message"); 125 trace!("got message {:?}", message);
129 let message = state.ping_send.subscribe().recv().await.expect("fatal error"); 126 return match message {
130 trace!("got message {:?}", message); 127 BroadcastCommands::Success(msg_uuid) => {
131 return match message { 128 if msg_uuid != uuid { continue; }
132 BroadcastCommands::Success(msg_uuid) => { 129 trace!("message == uuid success");
133 if msg_uuid != uuid { continue; } 130 Message::Text(format!("start_{uuid}"))
134 trace!("message == uuid success"); 131 },
135 Message::Text(format!("start_{}", uuid)) 132 BroadcastCommands::Timeout(msg_uuid) => {
136 }, 133 if msg_uuid != uuid { continue; }
137 BroadcastCommands::Timeout(msg_uuid) => { 134 trace!("message == uuid timeout");
138 if msg_uuid != uuid { continue; } 135 Message::Text(format!("timeout_{uuid}"))
139 trace!("message == uuid timeout"); 136 },
140 Message::Text(format!("timeout_{}", uuid)) 137 BroadcastCommands::Error(msg_uuid) => {
141 }, 138 if msg_uuid != uuid { continue; }
142 BroadcastCommands::Error(msg_uuid) => { 139 trace!("message == uuid error");
143 if msg_uuid != uuid { continue; } 140 Message::Text(format!("error_{uuid}"))
144 trace!("message == uuid error");
145 Message::Text(format!("error_{}", uuid))
146 }
147 } 141 }
148 } 142 }
149 } 143 }
150 } 144 }
151} \ No newline at end of file 145}
diff --git a/src/wol.rs b/src/wol.rs
index 0cdcae3..83c0ee6 100644
--- a/src/wol.rs
+++ b/src/wol.rs
@@ -1,17 +1,17 @@
1use std::net::{SocketAddr, UdpSocket}; 1use std::net::{SocketAddr, UdpSocket};
2 2
3use crate::error::WebolError; 3use crate::error::Error;
4 4
5/// Creates the magic packet from a mac address 5/// Creates the magic packet from a mac address
6/// 6///
7/// # Panics 7/// # Panics
8/// 8///
9/// Panics if `mac_addr` is an invalid mac 9/// Panics if `mac_addr` is an invalid mac
10pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, WebolError> { 10pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, Error> {
11 let mut mac = Vec::new(); 11 let mut mac = Vec::new();
12 let sp = mac_addr.split(':'); 12 let sp = mac_addr.split(':');
13 for f in sp { 13 for f in sp {
14 mac.push(u8::from_str_radix(f, 16).map_err(WebolError::BufferParse)?) 14 mac.push(u8::from_str_radix(f, 16).map_err(Error::BufferParse)?);
15 }; 15 };
16 let mut buf = vec![255; 6]; 16 let mut buf = vec![255; 6];
17 for _ in 0..16 { 17 for _ in 0..16 {
@@ -23,8 +23,8 @@ pub fn create_buffer(mac_addr: &str) -> Result<Vec<u8>, WebolError> {
23} 23}
24 24
25/// Sends a buffer on UDP broadcast 25/// Sends a buffer on UDP broadcast
26pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: Vec<u8>) -> Result<usize, WebolError> { 26pub fn send_packet(bind_addr: &SocketAddr, broadcast_addr: &SocketAddr, buffer: &[u8]) -> Result<usize, Error> {
27 let socket = UdpSocket::bind(bind_addr).map_err(WebolError::Broadcast)?; 27 let socket = UdpSocket::bind(bind_addr).map_err(Error::Broadcast)?;
28 socket.set_broadcast(true).map_err(WebolError::Broadcast)?; 28 socket.set_broadcast(true).map_err(Error::Broadcast)?;
29 socket.send_to(&buffer, broadcast_addr).map_err(WebolError::Broadcast) 29 socket.send_to(buffer, broadcast_addr).map_err(Error::Broadcast)
30} 30}