From a192e9baca9a14beaa9f87c27a63cff96aa41c94 Mon Sep 17 00:00:00 2001 From: FxQnLr Date: Sun, 25 Feb 2024 20:00:38 +0100 Subject: Closes #4. Auth on Websocket. Small stuff --- Cargo.lock | 7 +++++++ Cargo.toml | 1 + src/error.rs | 13 +++++++++++-- src/main.rs | 33 +++++++++++++++++++-------------- src/requests/device.rs | 23 ++++++++++------------- src/requests/start.rs | 24 ++++++++++++++++++------ 6 files changed, 66 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c0f07f7..25f23fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,6 +65,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "anyhow" +version = "1.0.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" + [[package]] name = "async-trait" version = "0.1.74" @@ -1725,6 +1731,7 @@ dependencies = [ name = "webol-cli" version = "0.2.0" dependencies = [ + "anyhow", "clap", "clap_complete", "config", diff --git a/Cargo.toml b/Cargo.toml index a60d788..4791c6c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ name = "webol" path = "src/main.rs" [dependencies] +anyhow = "1.0" clap = { version = "4.5", features = ["derive"] } clap_complete = "4.5" config = "0.14" diff --git a/src/error.rs b/src/error.rs index 15e4308..1e6eac1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,8 +27,17 @@ pub enum Error { #[error("parse header: {source}")] InvalidHeaderValue { #[from] - source: InvalidHeaderValue + source: InvalidHeaderValue, }, - #[error("ws")] + #[error("tungstenite: {source}")] + Tungstenite { + #[from] + source: tokio_tungstenite::tungstenite::Error, + }, + #[error("faulty websocket response")] WsResponse, + #[error("authorization failed")] + Authorization, + #[error("Http error status: {0}")] + HttpStatus(u16), } diff --git a/src/main.rs b/src/main.rs index cdca6cb..d76341f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,10 @@ use clap_complete::{generate, Generator, Shell}; use error::Error; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use requests::{device, start::start}; -use reqwest::header::{HeaderMap, HeaderValue}; +use reqwest::{ + header::{HeaderMap, HeaderValue}, + Response, +}; use serde::Deserialize; mod config; @@ -66,7 +69,7 @@ enum DeviceCmd { } #[tokio::main] -async fn main() -> Result<(), Error> { +async fn main() -> Result<(), anyhow::Error> { let config = Config::load()?; let cli = Args::parse(); @@ -112,18 +115,9 @@ fn print_completions(gen: G, cmd: &mut Command) { fn default_headers(config: &Config) -> Result { let mut map = HeaderMap::new(); - map.append( - "Accept-Content", - HeaderValue::from_str("application/json")? - ); - map.append( - "Content-Type", - HeaderValue::from_str("application/json")? - ); - map.append( - "Authorization", - HeaderValue::from_str(&config.apikey)? - ); + map.append("Accept-Content", HeaderValue::from_str("application/json")?); + map.append("Content-Type", HeaderValue::from_str("application/json")?); + map.append("Authorization", HeaderValue::from_str(&config.apikey)?); Ok(map) } @@ -132,6 +126,17 @@ fn format_url(config: &Config, path: &str, protocol: &Protocols) -> String { format!("{}://{}/{}", protocol, config.server, path) } +async fn check_success(res: Response) -> Result { + let status = res.status(); + if status.is_success() { + Ok(res.text().await?) + } else if status.as_u16() == 401 { + Err(Error::Authorization) + } else { + Err(Error::HttpStatus(status.as_u16())) + } +} + fn add_pb(mp: &MultiProgress, template: &str, message: String) -> ProgressBar { let pb = mp.add(ProgressBar::new(1)); pb.set_style(ProgressStyle::with_template(template).unwrap()); diff --git a/src/requests/device.rs b/src/requests/device.rs index a612978..7583406 100644 --- a/src/requests/device.rs +++ b/src/requests/device.rs @@ -1,4 +1,4 @@ -use crate::{config::Config, default_headers, error::Error, format_url, Protocols}; +use crate::{check_success, config::Config, default_headers, error::Error, format_url, Protocols}; pub async fn put( config: &Config, @@ -16,11 +16,10 @@ pub async fn put( r#"{{"id": "{id}", "mac": "{mac}", "broadcast_addr": "{broadcast_addr}", "ip": "{ip}"}}"#, )) .send() - .await? - .text() - .await; + .await?; - println!("{res:?}"); + let body = check_success(res).await?; + println!("{body}"); Ok(()) } @@ -30,11 +29,10 @@ pub async fn get(config: &Config, id: String) -> Result<(), Error> { .headers(default_headers(config)?) .body(format!(r#"{{"id": "{id}"}}"#)) .send() - .await? - .text() - .await; + .await?; - println!("{res:?}"); + let body = check_success(res).await?; + println!("{body}"); Ok(()) } @@ -52,10 +50,9 @@ pub async fn post( r#"{{"id": "{id}", "mac": "{mac}", "broadcast_addr": "{broadcast_addr}", "ip": "{ip}"}}"#, )) .send() - .await? - .text() - .await; + .await?; - println!("{res:?}"); + let body = check_success(res).await?; + println!("{body}"); Ok(()) } diff --git a/src/requests/start.rs b/src/requests/start.rs index 7abbbe0..d07177e 100644 --- a/src/requests/start.rs +++ b/src/requests/start.rs @@ -2,7 +2,10 @@ use futures_util::{SinkExt, StreamExt}; use indicatif::{MultiProgress, ProgressBar}; use reqwest::StatusCode; use serde::Deserialize; -use tokio_tungstenite::{connect_async, tungstenite::Message}; +use tokio_tungstenite::{ + connect_async, + tungstenite::{http::Request, Message}, +}; use crate::{ add_pb, config::Config, default_headers, error::Error, finish_pb, format_url, ErrorResponse, @@ -66,17 +69,26 @@ async fn status_socket( id: String, ) -> Result { let ws_pb = add_pb(pb, DEFAULT_STYLE, "connect to websocket".to_string()); - let (mut ws_stream, _response) = - connect_async(format_url(config, "status", &Protocols::Websocket)) - .await - .expect("Failed to connect"); + + let request = Request::builder() + .uri(format_url(config, "status", &Protocols::Websocket)) + .header("Authorization", &config.apikey) + .header("sec-websocket-key", "") + .header("host", &config.server) + .header("upgrade", "websocket") + .header("connection", "upgrade") + .header("sec-websocket-version", 13) + .body(()) + .unwrap(); + + let (mut ws_stream, _response) = connect_async(request).await?; finish_pb(&ws_pb, "connected to websocket".to_string(), DONE_STYLE); ws_stream.send(Message::Text(uuid.clone())).await.unwrap(); // Get ETA let eta_msg = ws_stream.next().await.unwrap().unwrap(); - let eta = get_eta(&eta_msg.into_text().unwrap(), &uuid)? + overview.elapsed().as_secs(); + let eta = get_eta(&eta_msg.into_text().unwrap(), &uuid)?; overview.set_message(format!("/{eta}) start {id}")); let msg_pb = add_pb(pb, DEFAULT_STYLE, "await message".to_string()); -- cgit v1.2.3