summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/auth.rs30
-rw-r--r--src/error.rs6
-rw-r--r--src/extractors.rs24
-rw-r--r--src/main.rs20
-rw-r--r--src/routes/device.rs113
-rw-r--r--src/routes/start.rs65
6 files changed, 107 insertions, 151 deletions
diff --git a/src/auth.rs b/src/auth.rs
deleted file mode 100644
index 22f87e7..0000000
--- a/src/auth.rs
+++ /dev/null
@@ -1,30 +0,0 @@
1use axum::http::HeaderValue;
2use tracing::{debug, trace};
3use crate::config::Config;
4use crate::error::Error;
5
6pub fn auth(config: &Config, secret: Option<&HeaderValue>) -> Result<Response, Error> {
7 debug!("auth request with secret {:?}", secret);
8 let res = if let Some(value) = secret {
9 trace!("auth value exists");
10 let key = &config.apikey;
11 if value.to_str()? == key.as_str() {
12 debug!("successful auth");
13 Response::Success
14 } else {
15 debug!("unsuccessful auth (wrong secret)");
16 Response::WrongSecret
17 }
18 } else {
19 debug!("unsuccessful auth (no secret)");
20 Response::MissingSecret
21 };
22 Ok(res)
23}
24
25#[derive(Debug)]
26pub enum Response {
27 Success,
28 WrongSecret,
29 MissingSecret
30}
diff --git a/src/error.rs b/src/error.rs
index 66a61f4..513b51b 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,8 +1,8 @@
1use ::ipnetwork::IpNetworkError;
1use axum::http::header::ToStrError; 2use axum::http::header::ToStrError;
2use axum::http::StatusCode; 3use axum::http::StatusCode;
3use axum::response::{IntoResponse, Response}; 4use axum::response::{IntoResponse, Response};
4use axum::Json; 5use axum::Json;
5use ::ipnetwork::IpNetworkError;
6use mac_address::MacParseError; 6use mac_address::MacParseError;
7use serde_json::json; 7use serde_json::json;
8use std::io; 8use std::io;
@@ -10,9 +10,6 @@ use tracing::error;
10 10
11#[derive(Debug, thiserror::Error)] 11#[derive(Debug, thiserror::Error)]
12pub enum Error { 12pub enum Error {
13 #[error("generic error")]
14 Generic,
15
16 #[error("db: {source}")] 13 #[error("db: {source}")]
17 Db { 14 Db {
18 #[from] 15 #[from]
@@ -54,7 +51,6 @@ impl IntoResponse for Error {
54 fn into_response(self) -> Response { 51 fn into_response(self) -> Response {
55 error!("{}", self.to_string()); 52 error!("{}", self.to_string());
56 let (status, error_message) = match self { 53 let (status, error_message) = match self {
57 Self::Generic => (StatusCode::INTERNAL_SERVER_ERROR, ""),
58 Self::Db { source } => { 54 Self::Db { source } => {
59 error!("{source}"); 55 error!("{source}");
60 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error") 56 (StatusCode::INTERNAL_SERVER_ERROR, "Server Error")
diff --git a/src/extractors.rs b/src/extractors.rs
new file mode 100644
index 0000000..4d441e9
--- /dev/null
+++ b/src/extractors.rs
@@ -0,0 +1,24 @@
1use axum::{
2 extract::{Request, State},
3 http::{HeaderMap, StatusCode},
4 middleware::Next,
5 response::Response,
6};
7
8use crate::AppState;
9
10pub async fn auth(
11 State(state): State<AppState>,
12 headers: HeaderMap,
13 request: Request,
14 next: Next,
15) -> Result<Response, StatusCode> {
16 let secret = headers.get("authorization");
17 match secret {
18 Some(token) if token == state.config.apikey.as_str() => {
19 let response = next.run(request).await;
20 Ok(response)
21 }
22 _ => Err(StatusCode::UNAUTHORIZED),
23 }
24}
diff --git a/src/main.rs b/src/main.rs
index 7d8c1da..eae89f6 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -4,26 +4,23 @@ use crate::routes::device;
4use crate::routes::start::start; 4use crate::routes::start::start;
5use crate::routes::status::status; 5use crate::routes::status::status;
6use crate::services::ping::StatusMap; 6use crate::services::ping::StatusMap;
7use axum::middleware::from_fn_with_state;
7use axum::routing::{get, put}; 8use axum::routing::{get, put};
8use axum::{routing::post, Router}; 9use axum::{routing::post, Router};
9use dashmap::DashMap; 10use dashmap::DashMap;
10use services::ping::BroadcastCommand; 11use services::ping::BroadcastCommand;
11use sqlx::PgPool; 12use sqlx::PgPool;
12use tracing_subscriber::fmt::time::UtcTime;
13use std::env; 13use std::env;
14use std::sync::Arc; 14use std::sync::Arc;
15use tokio::sync::broadcast::{channel, Sender}; 15use tokio::sync::broadcast::{channel, Sender};
16use tracing::{info, level_filters::LevelFilter}; 16use tracing::{info, level_filters::LevelFilter};
17use tracing_subscriber::{ 17use tracing_subscriber::fmt::time::UtcTime;
18 fmt, 18use tracing_subscriber::{fmt, prelude::*, EnvFilter};
19 prelude::*,
20 EnvFilter,
21};
22 19
23mod auth;
24mod config; 20mod config;
25mod db; 21mod db;
26mod error; 22mod error;
23mod extractors;
27mod routes; 24mod routes;
28mod services; 25mod services;
29mod wol; 26mod wol;
@@ -31,7 +28,6 @@ mod wol;
31#[tokio::main] 28#[tokio::main]
32async fn main() -> color_eyre::eyre::Result<()> { 29async fn main() -> color_eyre::eyre::Result<()> {
33 color_eyre::install()?; 30 color_eyre::install()?;
34
35 31
36 let time_format = 32 let time_format =
37 time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]"); 33 time::macros::format_description!("[year]-[month]-[day] [hour]:[minute]:[second]");
@@ -59,12 +55,12 @@ async fn main() -> color_eyre::eyre::Result<()> {
59 55
60 let ping_map: StatusMap = DashMap::new(); 56 let ping_map: StatusMap = DashMap::new();
61 57
62 let shared_state = Arc::new(AppState { 58 let shared_state = AppState {
63 db, 59 db,
64 config: config.clone(), 60 config: config.clone(),
65 ping_send: tx, 61 ping_send: tx,
66 ping_map, 62 ping_map,
67 }); 63 };
68 64
69 let app = Router::new() 65 let app = Router::new()
70 .route("/start", post(start)) 66 .route("/start", post(start))
@@ -72,7 +68,8 @@ async fn main() -> color_eyre::eyre::Result<()> {
72 .route("/device", put(device::put)) 68 .route("/device", put(device::put))
73 .route("/device", post(device::post)) 69 .route("/device", post(device::post))
74 .route("/status", get(status)) 70 .route("/status", get(status))
75 .with_state(shared_state); 71 .route_layer(from_fn_with_state(shared_state.clone(), extractors::auth))
72 .with_state(Arc::new(shared_state));
76 73
77 let addr = config.serveraddr; 74 let addr = config.serveraddr;
78 info!("start server on {}", addr); 75 info!("start server on {}", addr);
@@ -82,6 +79,7 @@ async fn main() -> color_eyre::eyre::Result<()> {
82 Ok(()) 79 Ok(())
83} 80}
84 81
82#[derive(Clone)]
85pub struct AppState { 83pub struct AppState {
86 db: PgPool, 84 db: PgPool,
87 config: Config, 85 config: Config,
diff --git a/src/routes/device.rs b/src/routes/device.rs
index 2f0093d..d39d98e 100644
--- a/src/routes/device.rs
+++ b/src/routes/device.rs
@@ -1,8 +1,6 @@
1use crate::auth::auth;
2use crate::db::Device; 1use crate::db::Device;
3use crate::error::Error; 2use crate::error::Error;
4use axum::extract::State; 3use axum::extract::State;
5use axum::http::HeaderMap;
6use axum::Json; 4use axum::Json;
7use mac_address::MacAddress; 5use mac_address::MacAddress;
8use serde::{Deserialize, Serialize}; 6use serde::{Deserialize, Serialize};
@@ -13,31 +11,24 @@ use tracing::{debug, info};
13 11
14pub async fn get( 12pub async fn get(
15 State(state): State<Arc<crate::AppState>>, 13 State(state): State<Arc<crate::AppState>>,
16 headers: HeaderMap,
17 Json(payload): Json<GetDevicePayload>, 14 Json(payload): Json<GetDevicePayload>,
18) -> Result<Json<Value>, Error> { 15) -> Result<Json<Value>, Error> {
19 info!("get device {}", payload.id); 16 info!("get device {}", payload.id);
20 let secret = headers.get("authorization"); 17 let device = sqlx::query_as!(
21 let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); 18 Device,
22 if authorized { 19 r#"
23 let device = sqlx::query_as!( 20 SELECT id, mac, broadcast_addr, ip, times
24 Device, 21 FROM devices
25 r#" 22 WHERE id = $1;
26 SELECT id, mac, broadcast_addr, ip, times 23 "#,
27 FROM devices 24 payload.id
28 WHERE id = $1; 25 )
29 "#, 26 .fetch_one(&state.db)
30 payload.id 27 .await?;
31 )
32 .fetch_one(&state.db)
33 .await?;
34 28
35 debug!("got device {:?}", device); 29 debug!("got device {:?}", device);
36 30
37 Ok(Json(json!(device))) 31 Ok(Json(json!(device)))
38 } else {
39 Err(Error::Generic)
40 }
41} 32}
42 33
43#[derive(Deserialize)] 34#[derive(Deserialize)]
@@ -47,7 +38,6 @@ pub struct GetDevicePayload {
47 38
48pub async fn put( 39pub async fn put(
49 State(state): State<Arc<crate::AppState>>, 40 State(state): State<Arc<crate::AppState>>,
50 headers: HeaderMap,
51 Json(payload): Json<PutDevicePayload>, 41 Json(payload): Json<PutDevicePayload>,
52) -> Result<Json<Value>, Error> { 42) -> Result<Json<Value>, Error> {
53 info!( 43 info!(
@@ -55,28 +45,22 @@ pub async fn put(
55 payload.id, payload.mac, payload.broadcast_addr, payload.ip 45 payload.id, payload.mac, payload.broadcast_addr, payload.ip
56 ); 46 );
57 47
58 let secret = headers.get("authorization"); 48 let ip = IpNetwork::from_str(&payload.ip)?;
59 let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); 49 let mac = MacAddress::from_str(&payload.mac)?;
60 if authorized { 50 sqlx::query!(
61 let ip = IpNetwork::from_str(&payload.ip)?; 51 r#"
62 let mac = MacAddress::from_str(&payload.mac)?; 52 INSERT INTO devices (id, mac, broadcast_addr, ip)
63 sqlx::query!( 53 VALUES ($1, $2, $3, $4);
64 r#" 54 "#,
65 INSERT INTO devices (id, mac, broadcast_addr, ip) 55 payload.id,
66 VALUES ($1, $2, $3, $4); 56 mac,
67 "#, 57 payload.broadcast_addr,
68 payload.id, 58 ip
69 mac, 59 )
70 payload.broadcast_addr, 60 .execute(&state.db)
71 ip 61 .await?;
72 )
73 .execute(&state.db)
74 .await?;
75 62
76 Ok(Json(json!(PutDeviceResponse { success: true }))) 63 Ok(Json(json!(PutDeviceResponse { success: true })))
77 } else {
78 Err(Error::Generic)
79 }
80} 64}
81 65
82#[derive(Deserialize)] 66#[derive(Deserialize)]
@@ -94,37 +78,30 @@ pub struct PutDeviceResponse {
94 78
95pub async fn post( 79pub async fn post(
96 State(state): State<Arc<crate::AppState>>, 80 State(state): State<Arc<crate::AppState>>,
97 headers: HeaderMap,
98 Json(payload): Json<PostDevicePayload>, 81 Json(payload): Json<PostDevicePayload>,
99) -> Result<Json<Value>, Error> { 82) -> Result<Json<Value>, Error> {
100 info!( 83 info!(
101 "edit device {} ({}, {}, {})", 84 "edit device {} ({}, {}, {})",
102 payload.id, payload.mac, payload.broadcast_addr, payload.ip 85 payload.id, payload.mac, payload.broadcast_addr, payload.ip
103 ); 86 );
104 let secret = headers.get("authorization"); 87 let ip = IpNetwork::from_str(&payload.ip)?;
105 let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); 88 let mac = MacAddress::from_str(&payload.mac)?;
106 if authorized { 89 let device = sqlx::query_as!(
107 let ip = IpNetwork::from_str(&payload.ip)?; 90 Device,
108 let mac = MacAddress::from_str(&payload.mac)?; 91 r#"
109 let device = sqlx::query_as!( 92 UPDATE devices
110 Device, 93 SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4
111 r#" 94 RETURNING id, mac, broadcast_addr, ip, times;
112 UPDATE devices 95 "#,
113 SET mac = $1, broadcast_addr = $2, ip = $3 WHERE id = $4 96 mac,
114 RETURNING id, mac, broadcast_addr, ip, times; 97 payload.broadcast_addr,
115 "#, 98 ip,
116 mac, 99 payload.id
117 payload.broadcast_addr, 100 )
118 ip, 101 .fetch_one(&state.db)
119 payload.id 102 .await?;
120 )
121 .fetch_one(&state.db)
122 .await?;
123 103
124 Ok(Json(json!(device))) 104 Ok(Json(json!(device)))
125 } else {
126 Err(Error::Generic)
127 }
128} 105}
129 106
130#[derive(Deserialize)] 107#[derive(Deserialize)]
diff --git a/src/routes/start.rs b/src/routes/start.rs
index 4888325..d4c0802 100644
--- a/src/routes/start.rs
+++ b/src/routes/start.rs
@@ -1,10 +1,8 @@
1use crate::auth::auth;
2use crate::db::Device; 1use crate::db::Device;
3use crate::error::Error; 2use crate::error::Error;
4use crate::services::ping::Value as PingValue; 3use crate::services::ping::Value as PingValue;
5use crate::wol::{create_buffer, send_packet}; 4use crate::wol::{create_buffer, send_packet};
6use axum::extract::State; 5use axum::extract::State;
7use axum::http::HeaderMap;
8use axum::Json; 6use axum::Json;
9use serde::{Deserialize, Serialize}; 7use serde::{Deserialize, Serialize};
10use serde_json::{json, Value}; 8use serde_json::{json, Value};
@@ -14,48 +12,41 @@ use uuid::Uuid;
14 12
15pub async fn start( 13pub async fn start(
16 State(state): State<Arc<crate::AppState>>, 14 State(state): State<Arc<crate::AppState>>,
17 headers: HeaderMap,
18 Json(payload): Json<Payload>, 15 Json(payload): Json<Payload>,
19) -> Result<Json<Value>, Error> { 16) -> Result<Json<Value>, Error> {
20 info!("POST request"); 17 info!("POST request");
21 let secret = headers.get("authorization"); 18 let device = sqlx::query_as!(
22 let authorized = matches!(auth(&state.config, secret)?, crate::auth::Response::Success); 19 Device,
23 if authorized { 20 r#"
24 let device = sqlx::query_as!( 21 SELECT id, mac, broadcast_addr, ip, times
25 Device, 22 FROM devices
26 r#" 23 WHERE id = $1;
27 SELECT id, mac, broadcast_addr, ip, times 24 "#,
28 FROM devices 25 payload.id
29 WHERE id = $1; 26 )
30 "#, 27 .fetch_one(&state.db)
31 payload.id 28 .await?;
32 )
33 .fetch_one(&state.db)
34 .await?;
35 29
36 info!("starting {}", device.id); 30 info!("starting {}", device.id);
37 31
38 let bind_addr = "0.0.0.0:0"; 32 let bind_addr = "0.0.0.0:0";
39 33
40 let _ = send_packet( 34 let _ = send_packet(
41 bind_addr, 35 bind_addr,
42 &device.broadcast_addr, 36 &device.broadcast_addr,
43 &create_buffer(&device.mac.to_string())?, 37 &create_buffer(&device.mac.to_string())?,
44 )?; 38 )?;
45 let dev_id = device.id.clone(); 39 let dev_id = device.id.clone();
46 let uuid = if payload.ping.is_some_and(|ping| ping) { 40 let uuid = if payload.ping.is_some_and(|ping| ping) {
47 Some(setup_ping(state, device)) 41 Some(setup_ping(state, device))
48 } else {
49 None
50 };
51 Ok(Json(json!(Response {
52 id: dev_id,
53 boot: true,
54 uuid
55 })))
56 } else { 42 } else {
57 Err(Error::Generic) 43 None
58 } 44 };
45 Ok(Json(json!(Response {
46 id: dev_id,
47 boot: true,
48 uuid
49 })))
59} 50}
60 51
61fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String { 52fn setup_ping(state: Arc<crate::AppState>, device: Device) -> String {