aboutsummaryrefslogtreecommitdiff
path: root/src/routes
diff options
context:
space:
mode:
Diffstat (limited to 'src/routes')
-rw-r--r--src/routes/device.rs137
-rw-r--r--src/routes/mod.rs3
-rw-r--r--src/routes/start.rs138
-rw-r--r--src/routes/status.rs79
4 files changed, 214 insertions, 143 deletions
diff --git a/src/routes/device.rs b/src/routes/device.rs
index c85df1b..d39d98e 100644
--- a/src/routes/device.rs
+++ b/src/routes/device.rs
@@ -1,34 +1,34 @@
1use std::sync::Arc; 1use crate::db::Device;
2use crate::error::Error;
2use axum::extract::State; 3use axum::extract::State;
3use axum::Json; 4use axum::Json;
4use axum::http::HeaderMap; 5use mac_address::MacAddress;
5use serde::{Deserialize, Serialize}; 6use serde::{Deserialize, Serialize};
6use serde_json::{json, Value}; 7use serde_json::{json, Value};
8use sqlx::types::ipnetwork::IpNetwork;
9use std::{sync::Arc, str::FromStr};
7use tracing::{debug, info}; 10use tracing::{debug, info};
8use crate::auth::auth;
9use crate::db::Device;
10use crate::error::Error;
11 11
12pub async fn get(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<GetDevicePayload>) -> Result<Json<Value>, Error> { 12pub async fn get(
13 info!("add device {}", payload.id); 13 State(state): State<Arc<crate::AppState>>,
14 let secret = headers.get("authorization"); 14 Json(payload): Json<GetDevicePayload>,
15 if auth(&state.config, secret).map_err(Error::Auth)? { 15) -> Result<Json<Value>, Error> {
16 let device = sqlx::query_as!( 16 info!("get device {}", payload.id);
17 Device, 17 let device = sqlx::query_as!(
18 r#" 18 Device,
19 SELECT id, mac, broadcast_addr, ip, times 19 r#"
20 FROM devices 20 SELECT id, mac, broadcast_addr, ip, times
21 WHERE id = $1; 21 FROM devices
22 "#, 22 WHERE id = $1;
23 payload.id 23 "#,
24 ).fetch_one(&state.db).await.map_err(Error::DB)?; 24 payload.id
25 )
26 .fetch_one(&state.db)
27 .await?;
25 28
26 debug!("got device {:?}", device); 29 debug!("got device {:?}", device);
27 30
28 Ok(Json(json!(device))) 31 Ok(Json(json!(device)))
29 } else {
30 Err(Error::Generic)
31 }
32} 32}
33 33
34#[derive(Deserialize)] 34#[derive(Deserialize)]
@@ -36,25 +36,31 @@ pub struct GetDevicePayload {
36 id: String, 36 id: String,
37} 37}
38 38
39pub async fn put(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PutDevicePayload>) -> Result<Json<Value>, Error> { 39pub async fn put(
40 info!("add device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); 40 State(state): State<Arc<crate::AppState>>,
41 let secret = headers.get("authorization"); 41 Json(payload): Json<PutDevicePayload>,
42 if auth(&state.config, secret).map_err(Error::Auth)? { 42) -> Result<Json<Value>, Error> {
43 sqlx::query!( 43 info!(
44 r#" 44 "add device {} ({}, {}, {})",
45 INSERT INTO devices (id, mac, broadcast_addr, ip) 45 payload.id, payload.mac, payload.broadcast_addr, payload.ip
46 VALUES ($1, $2, $3, $4); 46 );
47 "#, 47
48 payload.id, 48 let ip = IpNetwork::from_str(&payload.ip)?;
49 payload.mac, 49 let mac = MacAddress::from_str(&payload.mac)?;
50 payload.broadcast_addr, 50 sqlx::query!(
51 payload.ip 51 r#"
52 ).execute(&state.db).await.map_err(Error::DB)?; 52 INSERT INTO devices (id, mac, broadcast_addr, ip)
53 VALUES ($1, $2, $3, $4);
54 "#,
55 payload.id,
56 mac,
57 payload.broadcast_addr,
58 ip
59 )
60 .execute(&state.db)
61 .await?;
53 62
54 Ok(Json(json!(PutDeviceResponse { success: true }))) 63 Ok(Json(json!(PutDeviceResponse { success: true })))
55 } else {
56 Err(Error::Generic)
57 }
58} 64}
59 65
60#[derive(Deserialize)] 66#[derive(Deserialize)]
@@ -62,35 +68,40 @@ pub struct PutDevicePayload {
62 id: String, 68 id: String,
63 mac: String, 69 mac: String,
64 broadcast_addr: String, 70 broadcast_addr: String,
65 ip: String 71 ip: String,
66} 72}
67 73
68#[derive(Serialize)] 74#[derive(Serialize)]
69pub struct PutDeviceResponse { 75pub struct PutDeviceResponse {
70 success: bool 76 success: bool,
71} 77}
72 78
73pub async fn post(State(state): State<Arc<crate::AppState>>, headers: HeaderMap, Json(payload): Json<PostDevicePayload>) -> Result<Json<Value>, Error> { 79pub async fn post(
74 info!("edit device {} ({}, {}, {})", payload.id, payload.mac, payload.broadcast_addr, payload.ip); 80 State(state): State<Arc<crate::AppState>>,
75 let secret = headers.get("authorization"); 81 Json(payload): Json<PostDevicePayload>,
76 if auth(&state.config, secret).map_err(Error::Auth)? { 82) -> Result<Json<Value>, Error> {
77 let device = sqlx::query_as!( 83 info!(
78 Device, 84 "edit device {} ({}, {}, {})",
79 r#" 85 payload.id, payload.mac, payload.broadcast_addr, payload.ip
80 UPDATE devices 86 );
81 SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 87 let ip = IpNetwork::from_str(&payload.ip)?;
82 RETURNING id, mac, broadcast_addr, ip, times; 88 let mac = MacAddress::from_str(&payload.mac)?;
83 "#, 89 let device = sqlx::query_as!(
84 payload.mac, 90 Device,
85 payload.broadcast_addr, 91 r#"
86 payload.ip, 92 UPDATE devices
87 payload.id 93 SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4
88 ).fetch_one(&state.db).await.map_err(Error::DB)?; 94 RETURNING id, mac, broadcast_addr, ip, times;
95 "#,
96 mac,
97 payload.broadcast_addr,
98 ip,
99 payload.id
100 )
101 .fetch_one(&state.db)
102 .await?;
89 103
90 Ok(Json(json!(device))) 104 Ok(Json(json!(device)))
91 } else {
92 Err(Error::Generic)
93 }
94} 105}
95 106
96#[derive(Deserialize)] 107#[derive(Deserialize)]
diff --git a/src/routes/mod.rs b/src/routes/mod.rs
deleted file mode 100644
index d5ab0d6..0000000
--- a/src/routes/mod.rs
+++ /dev/null
@@ -1,3 +0,0 @@
1pub mod start;
2pub mod device;
3pub mod status; \ No newline at end of file
diff --git a/src/routes/start.rs b/src/routes/start.rs
index ce95bf3..d4c0802 100644
--- a/src/routes/start.rs
+++ b/src/routes/start.rs
@@ -1,10 +1,8 @@
1use crate::auth::auth;
2use crate::db::Device; 1use crate::db::Device;
3use crate::error::Error; 2use crate::error::Error;
4use crate::services::ping::Value as PingValue; 3use crate::services::ping::Value as PingValue;
5use crate::wol::{create_buffer, send_packet}; 4use crate::wol::{create_buffer, send_packet};
6use axum::extract::State; 5use axum::extract::State;
7use axum::http::HeaderMap;
8use axum::Json; 6use axum::Json;
9use serde::{Deserialize, Serialize}; 7use serde::{Deserialize, Serialize};
10use serde_json::{json, Value}; 8use serde_json::{json, Value};
@@ -12,86 +10,82 @@ use std::sync::Arc;
12use tracing::{debug, info}; 10use tracing::{debug, info};
13use uuid::Uuid; 11use uuid::Uuid;
14 12
15#[axum_macros::debug_handler]
16pub async fn start( 13pub async fn start(
17 State(state): State<Arc<crate::AppState>>, 14 State(state): State<Arc<crate::AppState>>,
18 headers: HeaderMap,
19 Json(payload): Json<Payload>, 15 Json(payload): Json<Payload>,
20) -> Result<Json<Value>, Error> { 16) -> Result<Json<Value>, Error> {
21 info!("POST request"); 17 info!("POST request");
22 let secret = headers.get("authorization"); 18 let device = sqlx::query_as!(
23 let authorized = auth(&state.config, secret).map_err(Error::Auth)?; 19 Device,
24 if authorized { 20 r#"
25 let device = sqlx::query_as!( 21 SELECT id, mac, broadcast_addr, ip, times
26 Device, 22 FROM devices
27 r#" 23 WHERE id = $1;
28 SELECT id, mac, broadcast_addr, ip, times 24 "#,
29 FROM devices 25 payload.id
30 WHERE id = $1; 26 )
31 "#, 27 .fetch_one(&state.db)
32 payload.id 28 .await?;
33 )
34 .fetch_one(&state.db)
35 .await
36 .map_err(Error::DB)?;
37
38 info!("starting {}", device.id);
39
40 let bind_addr = "0.0.0.0:0";
41 29
42 let _ = send_packet( 30 info!("starting {}", device.id);
43 &bind_addr.parse().map_err(Error::IpParse)?,
44 &device.broadcast_addr.parse().map_err(Error::IpParse)?,
45 &create_buffer(&device.mac)?,
46 )?;
47 let dev_id = device.id.clone();
48 let uuid = if payload.ping.is_some_and(|ping| ping) {
49 let mut uuid: Option<String> = None;
50 for (key, value) in state.ping_map.clone() {
51 if value.ip == device.ip {
52 debug!("service already exists");
53 uuid = Some(key);
54 break;
55 }
56 }
57 let uuid_gen = match uuid {
58 Some(u) => u,
59 None => Uuid::new_v4().to_string(),
60 };
61 let uuid_genc = uuid_gen.clone();
62 31
63 tokio::spawn(async move { 32 let bind_addr = "0.0.0.0:0";
64 debug!("init ping service");
65 state.ping_map.insert(
66 uuid_gen.clone(),
67 PingValue {
68 ip: device.ip.clone(),
69 online: false,
70 },
71 );
72 33
73 crate::services::ping::spawn( 34 let _ = send_packet(
74 state.ping_send.clone(), 35 bind_addr,
75 &state.config, 36 &device.broadcast_addr,
76 device, 37 &create_buffer(&device.mac.to_string())?,
77 uuid_gen.clone(), 38 )?;
78 &state.ping_map, 39 let dev_id = device.id.clone();
79 &state.db, 40 let uuid = if payload.ping.is_some_and(|ping| ping) {
80 ) 41 Some(setup_ping(state, device))
81 .await;
82 });
83 Some(uuid_genc)
84 } else {
85 None
86 };
87 Ok(Json(json!(Response {
88 id: dev_id,
89 boot: true,
90 uuid
91 })))
92 } else { 42 } else {
93 Err(Error::Generic) 43 None
44 };
45 Ok(Json(json!(Response {
46 id: dev_id,
47 boot: true,
48 uuid
49 })))
50}
51
52fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String {
53 let mut uuid: Option<String> = None;
54 for (key, value) in state.ping_map.clone() {
55 if value.ip == device.ip {
56 debug!("service already exists");
57 uuid = Some(key);
58 break;
59 }
94 } 60 }
61 let uuid_gen = match uuid {
62 Some(u) => u,
63 None => Uuid::new_v4().to_string(),
64 };
65 let uuid_ret = uuid_gen.clone();
66
67 debug!("init ping service");
68 state.ping_map.insert(
69 uuid_gen.clone(),
70 PingValue {
71 ip: device.ip,
72 online: false,
73 },
74 );
75
76 tokio::spawn(async move {
77 crate::services::ping::spawn(
78 state.ping_send.clone(),
79 &state.config,
80 device,
81 uuid_gen,
82 &state.ping_map,
83 &state.db,
84 )
85 .await;
86 });
87
88 uuid_ret
95} 89}
96 90
97#[derive(Deserialize)] 91#[derive(Deserialize)]
diff --git a/src/routes/status.rs b/src/routes/status.rs
index 31ef996..0e25f7d 100644
--- a/src/routes/status.rs
+++ b/src/routes/status.rs
@@ -1,10 +1,79 @@
1use std::sync::Arc; 1use crate::services::ping::BroadcastCommand;
2use crate::AppState;
3use axum::extract::ws::{Message, WebSocket};
2use axum::extract::{State, WebSocketUpgrade}; 4use axum::extract::{State, WebSocketUpgrade};
3use axum::response::Response; 5use axum::response::Response;
4use crate::AppState; 6use sqlx::PgPool;
5use crate::services::ping::status_websocket; 7use std::sync::Arc;
8use tracing::{debug, trace};
6 9
7#[axum_macros::debug_handler]
8pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response { 10pub async fn status(State(state): State<Arc<AppState>>, ws: WebSocketUpgrade) -> Response {
9 ws.on_upgrade(move |socket| status_websocket(socket, state)) 11 ws.on_upgrade(move |socket| websocket(socket, state))
12}
13
14pub async fn websocket(mut socket: WebSocket, state: Arc<AppState>) {
15 trace!("wait for ws message (uuid)");
16 let msg = socket.recv().await;
17 let uuid = msg.unwrap().unwrap().into_text().unwrap();
18
19 trace!("Search for uuid: {}", uuid);
20
21 let eta = get_eta(&state.db).await;
22 let _ = socket
23 .send(Message::Text(format!("eta_{eta}_{uuid}")))
24 .await;
25
26 let device_exists = state.ping_map.contains_key(&uuid);
27 if device_exists {
28 let _ = socket
29 .send(receive_ping_broadcast(state.clone(), uuid).await)
30 .await;
31 } else {
32 debug!("didn't find any device");
33 let _ = socket.send(Message::Text(format!("notfound_{uuid}"))).await;
34 };
35
36 let _ = socket.close().await;
37}
38
39async fn receive_ping_broadcast(state: Arc<AppState>, uuid: String) -> Message {
40 let pm = state.ping_map.clone().into_read_only();
41 let device = pm.get(&uuid).expect("fatal error");
42 debug!("got device: {} (online: {})", device.ip, device.online);
43 if device.online {
44 debug!("already started");
45 Message::Text(BroadcastCommand::success(uuid).to_string())
46 } else {
47 loop {
48 trace!("wait for tx message");
49 let message = state
50 .ping_send
51 .subscribe()
52 .recv()
53 .await
54 .expect("fatal error");
55 trace!("got message {:?}", message);
56
57 if message.uuid != uuid {
58 continue;
59 }
60 trace!("message == uuid success");
61 return Message::Text(message.to_string());
62 }
63 }
64}
65
66async fn get_eta(db: &PgPool) -> i64 {
67 let query = sqlx::query!(r#"SELECT times FROM devices;"#)
68 .fetch_one(db)
69 .await
70 .unwrap();
71
72 let times = if let Some(times) = query.times {
73 times
74 } else {
75 vec![0]
76 };
77
78 times.iter().sum::<i64>() / i64::try_from(times.len()).unwrap()
10} 79}