From 03bea24f9de698375033af92a08762446d0e20cc Mon Sep 17 00:00:00 2001 From: FxQnLr Date: Sun, 25 Feb 2024 16:14:56 +0100 Subject: Closes #2. Config and setup stuff --- src/config.rs | 30 +++++++------- src/error.rs | 10 ++--- src/main.rs | 84 +++++++++++++++++++++------------------- src/requests.rs | 2 + src/requests/device.rs | 62 ++++++++++++++--------------- src/requests/mod.rs | 2 - src/requests/start.rs | 103 ++++++++++++++++++++++++++++--------------------- 7 files changed, 158 insertions(+), 135 deletions(-) create mode 100644 src/requests.rs delete mode 100644 src/requests/mod.rs (limited to 'src') diff --git a/src/config.rs b/src/config.rs index 9a9e44b..78795a3 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,19 +1,19 @@ -use config::Config; -use once_cell::sync::Lazy; +use serde::Deserialize; -pub static SETTINGS: Lazy = Lazy::new(setup); - -fn setup() -> Config { - #[cfg(not(debug_assertions))] - let builder = Config::builder().add_source(config::File::with_name( - format!("{}/webol-cli.toml", dirs::config_dir().unwrap().to_string_lossy()).as_str(), - )); +#[derive(Deserialize)] +pub struct Config { + pub apikey: String, + pub server: String, +} - #[cfg(debug_assertions)] - let builder = Config::builder().add_source(config::File::with_name("webol-cli.toml")); +impl Config { + pub fn load() -> Result { + let builder = config::Config::builder() + .add_source(config::File::with_name("~/.config/webol-cli.toml")) + .add_source(config::File::with_name("webol-cli.toml")) + .add_source(config::Environment::with_prefix("WEBOL_CLI_").separator("_")) + .build()?; - builder - .add_source(config::Environment::with_prefix("WEBOL_CLI_").separator("_")) - .build() - .unwrap() + builder.try_deserialize() + } } diff --git a/src/error.rs b/src/error.rs index f15c60a..531528f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,11 +11,11 @@ pub enum CliError { impl Debug for CliError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Reqwest(err) => { err.fmt(f) }, - Self::Config(err) => { err.fmt(f) }, - Self::Serde(err) => { err.fmt(f) }, - Self::Parse(err) => { err.fmt(f) }, - Self::WsResponse => { f.write_str("Error in Response") }, + Self::Reqwest(err) => err.fmt(f), + Self::Config(err) => err.fmt(f), + Self::Serde(err) => err.fmt(f), + Self::Parse(err) => err.fmt(f), + Self::WsResponse => f.write_str("Error in Response"), } } } diff --git a/src/main.rs b/src/main.rs index afe6fac..0393183 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,11 @@ use std::{fmt::Display, time::Duration}; -use clap::{Parser, Command, CommandFactory, Subcommand}; -use clap_complete::{generate, Shell, Generator}; -use config::SETTINGS; +use crate::config::Config; +use clap::{Command, CommandFactory, Parser, Subcommand}; +use clap_complete::{generate, Generator, Shell}; use error::CliError; -use indicatif::{ProgressBar, ProgressStyle, MultiProgress}; -use requests::{start::start, device}; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; +use requests::{device, start::start}; use reqwest::header::{HeaderMap, HeaderValue}; use serde::Deserialize; @@ -35,7 +35,7 @@ enum Commands { /// id of the device id: String, #[arg(short, long)] - ping: Option + ping: Option, }, Device { #[command(subcommand)] @@ -52,7 +52,7 @@ enum DeviceCmd { id: String, mac: String, broadcast_addr: String, - ip: String + ip: String, }, Get { id: String, @@ -61,29 +61,39 @@ enum DeviceCmd { id: String, mac: String, broadcast_addr: String, - ip: String + ip: String, }, } #[tokio::main] async fn main() -> Result<(), CliError> { + let config = Config::load().map_err(CliError::Config)?; + let cli = Args::parse(); match cli.commands { Commands::Start { id, ping } => { - start(id, ping.unwrap_or(true)).await?; - }, - Commands::Device { devicecmd } => { - match devicecmd { - DeviceCmd::Add { id, mac, broadcast_addr, ip } => { - device::put(id, mac, broadcast_addr, ip).await?; - }, - DeviceCmd::Get { id } => { - device::get(id).await?; - }, - DeviceCmd::Edit { id, mac, broadcast_addr, ip } => { - device::post(id, mac, broadcast_addr, ip).await?; - }, + start(&config, id, ping.unwrap_or(true)).await?; + } + Commands::Device { devicecmd } => match devicecmd { + DeviceCmd::Add { + id, + mac, + broadcast_addr, + ip, + } => { + device::put(&config, id, mac, broadcast_addr, ip).await?; + } + DeviceCmd::Get { id } => { + device::get(&config, id).await?; + } + DeviceCmd::Edit { + id, + mac, + broadcast_addr, + ip, + } => { + device::post(&config, id, mac, broadcast_addr, ip).await?; } }, Commands::CliGen { id } => { @@ -100,29 +110,26 @@ fn print_completions(gen: G, cmd: &mut Command) { generate(gen, cmd, cmd.get_name().to_string(), &mut std::io::stdout()); } -fn default_headers() -> Result { +fn default_headers(config: &Config) -> Result { let mut map = HeaderMap::new(); - map.append("Accept-Content", HeaderValue::from_str("application/json").unwrap()); - map.append("Content-Type", HeaderValue::from_str("application/json").unwrap()); + map.append( + "Accept-Content", + HeaderValue::from_str("application/json").unwrap(), + ); + map.append( + "Content-Type", + HeaderValue::from_str("application/json").unwrap(), + ); map.append( "Authorization", - HeaderValue::from_str( - SETTINGS.get_string("key") - .map_err(CliError::Config)? - .as_str() - ).unwrap() + HeaderValue::from_str(&config.apikey).unwrap(), ); Ok(map) } -fn format_url(path: &str, protocol: Protocols) -> Result { - Ok(format!( - "{}://{}/{}", - protocol, - SETTINGS.get_string("server").map_err(CliError::Config)?, - path - )) +fn format_url(config: &Config, path: &str, protocol: Protocols) -> Result { + Ok(format!("{}://{}/{}", protocol, config.server, path)) } fn add_pb(mp: &MultiProgress, template: &str, message: String) -> ProgressBar { @@ -137,7 +144,6 @@ fn add_pb(mp: &MultiProgress, template: &str, message: String) -> ProgressBar { fn finish_pb(pb: ProgressBar, message: String, template: &str) { pb.set_style(ProgressStyle::with_template(template).unwrap()); pb.finish_with_message(message); - } enum Protocols { @@ -149,12 +155,12 @@ impl Display for Protocols { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Http => f.write_str("http"), - Self::Websocket => f.write_str("ws") + Self::Websocket => f.write_str("ws"), } } } #[derive(Debug, Deserialize)] struct ErrorResponse { - error: String + error: String, } diff --git a/src/requests.rs b/src/requests.rs new file mode 100644 index 0000000..6855db1 --- /dev/null +++ b/src/requests.rs @@ -0,0 +1,2 @@ +pub mod start; +pub mod device; diff --git a/src/requests/device.rs b/src/requests/device.rs index cbc838e..5003c4a 100644 --- a/src/requests/device.rs +++ b/src/requests/device.rs @@ -1,20 +1,21 @@ -use crate::{error::CliError, default_headers, format_url, Protocols}; +use crate::{config::Config, default_headers, error::CliError, format_url, Protocols}; -pub async fn put(id: String, mac: String, broadcast_addr: String, ip: String) -> Result<(), CliError> { - let url = format_url("device", Protocols::Http)?; +pub async fn put( + config: &Config, + id: String, + mac: String, + broadcast_addr: String, + ip: String, +) -> Result<(), CliError> { + let url = format_url(config, "device", Protocols::Http)?; println!("{}", url); let res = reqwest::Client::new() .put(url) - .headers(default_headers()?) - .body( - format!( - r#"{{"id": "{}", "mac": "{}", "broadcast_addr": "{}", "ip": "{}"}}"#, - id, - mac, - broadcast_addr, - ip - ) - ) + .headers(default_headers(config)?) + .body(format!( + r#"{{"id": "{}", "mac": "{}", "broadcast_addr": "{}", "ip": "{}"}}"#, + id, mac, broadcast_addr, ip + )) .send() .await .map_err(CliError::Reqwest)? @@ -25,13 +26,11 @@ pub async fn put(id: String, mac: String, broadcast_addr: String, ip: String) -> Ok(()) } -pub async fn get(id: String) -> Result<(), CliError> { +pub async fn get(config: &Config, id: String) -> Result<(), CliError> { let res = reqwest::Client::new() - .get(format_url("device", Protocols::Http)?) - .headers(default_headers()?) - .body( - format!(r#"{{"id": "{}"}}"#, id) - ) + .get(format_url(config, "device", Protocols::Http)?) + .headers(default_headers(config)?) + .body(format!(r#"{{"id": "{}"}}"#, id)) .send() .await .map_err(CliError::Reqwest)? @@ -42,19 +41,20 @@ pub async fn get(id: String) -> Result<(), CliError> { Ok(()) } -pub async fn post(id: String, mac: String, broadcast_addr: String, ip: String) -> Result<(), CliError> { +pub async fn post( + config: &Config, + id: String, + mac: String, + broadcast_addr: String, + ip: String, +) -> Result<(), CliError> { let res = reqwest::Client::new() - .post(format_url("device", Protocols::Http)?) - .headers(default_headers()?) - .body( - format!( - r#"{{"id": "{}", "mac": "{}", "broadcast_addr": "{}", "ip": "{}"}}"#, - id, - mac, - broadcast_addr, - ip - ) - ) + .post(format_url(config, "device", Protocols::Http)?) + .headers(default_headers(config)?) + .body(format!( + r#"{{"id": "{}", "mac": "{}", "broadcast_addr": "{}", "ip": "{}"}}"#, + id, mac, broadcast_addr, ip + )) .send() .await .map_err(CliError::Reqwest)? diff --git a/src/requests/mod.rs b/src/requests/mod.rs deleted file mode 100644 index 6855db1..0000000 --- a/src/requests/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod start; -pub mod device; diff --git a/src/requests/start.rs b/src/requests/start.rs index ca4ca44..bc63303 100644 --- a/src/requests/start.rs +++ b/src/requests/start.rs @@ -1,25 +1,26 @@ -use futures_util::{StreamExt, SinkExt}; +use futures_util::{SinkExt, StreamExt}; use indicatif::{MultiProgress, ProgressBar}; use reqwest::StatusCode; use serde::Deserialize; use tokio_tungstenite::{connect_async, tungstenite::Message}; -use crate::{error::CliError, default_headers, ErrorResponse, format_url, Protocols, OVERVIEW_STYLE, DEFAULT_STYLE, DONE_STYLE, finish_pb, ERROR_STYLE, OVERVIEW_ERROR, OVERVIEW_DONE, add_pb}; - -pub async fn start(id: String, ping: bool) -> Result<(), CliError> { +use crate::{ + add_pb, config::Config, default_headers, error::CliError, finish_pb, format_url, ErrorResponse, + Protocols, DEFAULT_STYLE, DONE_STYLE, ERROR_STYLE, OVERVIEW_DONE, OVERVIEW_ERROR, + OVERVIEW_STYLE, +}; +pub async fn start(config: &Config, id: String, ping: bool) -> Result<(), CliError> { let send_start = MultiProgress::new(); let overview = add_pb(&send_start, OVERVIEW_STYLE, format!(") start {}", id)); // TODO: calculate average start-time on server - let url = format_url("start", Protocols::Http)?; + let url = format_url(config, "start", Protocols::Http)?; let connect = add_pb(&send_start, DEFAULT_STYLE, format!("connect to {}", url)); let res = reqwest::Client::new() .post(url) - .headers(default_headers()?) - .body( - format!(r#"{{"id": "{}", "ping": {}}}"#, id, ping) - ) + .headers(default_headers(config)?) + .body(format!(r#"{{"id": "{}", "ping": {}}}"#, id, ping)) .send() .await .map_err(CliError::Reqwest)?; @@ -29,7 +30,7 @@ pub async fn start(id: String, ping: bool) -> Result<(), CliError> { match res.status() { StatusCode::OK => { let body = serde_json::from_str::( - &res.text().await.map_err(CliError::Reqwest)? + &res.text().await.map_err(CliError::Reqwest)?, ) .map_err(CliError::Serde)?; @@ -38,17 +39,25 @@ pub async fn start(id: String, ping: bool) -> Result<(), CliError> { } if ping { - let status = status_socket(body.uuid, &send_start, &overview, id).await?; + let status = status_socket(config, body.uuid, &send_start, &overview, id).await?; if status { - finish_pb(overview, format!("successfully started {}", body.id), OVERVIEW_DONE); + finish_pb( + overview, + format!("successfully started {}", body.id), + OVERVIEW_DONE, + ); } else { - finish_pb(overview, format!("error while starting {}", body.id), OVERVIEW_ERROR); + finish_pb( + overview, + format!("error while starting {}", body.id), + OVERVIEW_ERROR, + ); } } - }, + } _ => { let body = serde_json::from_str::( - &res.text().await.map_err(CliError::Reqwest)? + &res.text().await.map_err(CliError::Reqwest)?, ) .map_err(CliError::Serde)?; @@ -59,16 +68,22 @@ pub async fn start(id: String, ping: bool) -> Result<(), CliError> { Ok(()) } -async fn status_socket(uuid: String, pb: &MultiProgress, overview: &ProgressBar, id: String) -> Result { - // TODO: Remove unwraps +async fn status_socket( + config: &Config, + uuid: String, + pb: &MultiProgress, + overview: &ProgressBar, + id: String, +) -> Result { let ws_pb = add_pb(pb, DEFAULT_STYLE, "connect to websocket".to_string()); - let (mut ws_stream, _response) = connect_async(format_url("status", Protocols::Websocket)?) - .await - .expect("Failed to connect"); + let (mut ws_stream, _response) = + connect_async(format_url(config, "status", Protocols::Websocket)?) + .await + .expect("Failed to connect"); finish_pb(ws_pb, "connected to websocket".to_string(), DONE_STYLE); - + ws_stream.send(Message::Text(uuid.clone())).await.unwrap(); - + // Get ETA let eta_msg = ws_stream.next().await.unwrap().unwrap(); let eta = get_eta(eta_msg.into_text().unwrap(), uuid.clone())? + overview.elapsed().as_secs(); @@ -86,29 +101,29 @@ async fn status_socket(uuid: String, pb: &MultiProgress, overview: &ProgressBar, Verified::WrongUuid => { finish_pb(v_pb, "returned wrong uuid".to_string(), ERROR_STYLE); Ok(false) - }, - Verified::ResponseType(res_type) => { - match res_type { - ResponseType::Start => { - finish_pb(v_pb, "device started".to_string(), DONE_STYLE); - Ok(true) - }, - ResponseType::Timeout => { - finish_pb(v_pb, "ping timed out".to_string(), ERROR_STYLE); - Ok(false) - }, - ResponseType::NotFound => { - finish_pb(v_pb, "unknown uuid".to_string(), ERROR_STYLE); - Ok(false) - }, - } } + Verified::ResponseType(res_type) => match res_type { + ResponseType::Start => { + finish_pb(v_pb, "device started".to_string(), DONE_STYLE); + Ok(true) + } + ResponseType::Timeout => { + finish_pb(v_pb, "ping timed out".to_string(), ERROR_STYLE); + Ok(false) + } + ResponseType::NotFound => { + finish_pb(v_pb, "unknown uuid".to_string(), ERROR_STYLE); + Ok(false) + } + }, } } fn get_eta(msg: String, uuid: String) -> Result { let spl: Vec<&str> = msg.split('_').collect(); - if (spl[0] != "eta") || (spl[2] != uuid) { return Err(CliError::WsResponse); }; + if (spl[0] != "eta") || (spl[2] != uuid) { + return Err(CliError::WsResponse); + }; Ok(u64::from_str_radix(spl[1], 10).map_err(CliError::Parse)?) } @@ -116,9 +131,11 @@ fn verify_response(res: String, org_uuid: String) -> Result let spl: Vec<&str> = res.split('_').collect(); let res_type = spl[0]; let uuid = spl[1]; - - if uuid != org_uuid { return Ok(Verified::WrongUuid) }; - + + if uuid != org_uuid { + return Ok(Verified::WrongUuid); + }; + Ok(Verified::ResponseType(ResponseType::from(res_type)?)) } @@ -131,7 +148,7 @@ struct StartResponse { enum Verified { ResponseType(ResponseType), - WrongUuid + WrongUuid, } enum ResponseType { -- cgit v1.2.3