From 34272176867af1d566bd3bc561486a4e7a492ea9 Mon Sep 17 00:00:00 2001
From: Mathijs van Veluw <black.dex@gmail.com>
Date: Sun, 17 Mar 2024 19:52:55 +0100
Subject: [PATCH] Remove custom WebSocket code (#4001)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Remove custom WebSocket code

Remove our custom WebSocket code and only use the Rocket code.
Removed all options in regards to WebSockets
Added a new option `WEBSOCKET_DISABLED` which defaults too `false`.
This can be used to disable WebSockets if you really do not want to use it.

* Addressed remarks given and some updates

- Addressed comments given during review
- Updated crates, including Rocket to the latest merged v0.5 changes
- Removed an extra header which should not be sent for websocket connections

* Updated suggestions and crates

- Addressed the suggestions
- Updated Rocket to latest rc4
  Also made the needed code changes
- Updated all other crates
  Pinned `openssl` and `openssl-sys`

---------

Co-authored-by: Daniel GarcĂ­a <dani-garcia@users.noreply.github.com>
---
 .env.template            |   8 +-
 Cargo.lock               |   1 -
 Cargo.toml               |   1 -
 src/api/mod.rs           |   2 +-
 src/api/notifications.rs | 228 ++++++++++++---------------------------
 src/config.rs            |  18 +---
 src/error.rs             |   2 -
 src/main.rs              |   4 +-
 8 files changed, 77 insertions(+), 187 deletions(-)

diff --git a/.env.template b/.env.template
index 46aa6271..61cd046b 100644
--- a/.env.template
+++ b/.env.template
@@ -84,12 +84,8 @@
 ### WebSocket ###
 #################
 
-## Enables websocket notifications
-# WEBSOCKET_ENABLED=false
-
-## Controls the WebSocket server address and port
-# WEBSOCKET_ADDRESS=0.0.0.0
-# WEBSOCKET_PORT=3012
+## Enable websocket notifications
+# ENABLE_WEBSOCKET=true
 
 ##########################
 ### Push notifications ###
diff --git a/Cargo.lock b/Cargo.lock
index 86f0f234..b83eb071 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3784,7 +3784,6 @@ dependencies = [
  "syslog",
  "time",
  "tokio",
- "tokio-tungstenite",
  "totp-lite",
  "tracing",
  "url",
diff --git a/Cargo.toml b/Cargo.toml
index e5a3edd9..26916626 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -60,7 +60,6 @@ rocket = { version = "0.5.0", features = ["tls", "json"], default-features = fal
 rocket_ws = { version ="0.1.0" }
 
 # WebSockets libraries
-tokio-tungstenite = "0.20.1"
 rmpv = "1.0.1" # MessagePack library
 
 # Concurrent HashMap used for WebSocket messaging and favicons
diff --git a/src/api/mod.rs b/src/api/mod.rs
index 99915bdf..c6838aaa 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -23,7 +23,7 @@ pub use crate::api::{
     icons::routes as icons_routes,
     identity::routes as identity_routes,
     notifications::routes as notifications_routes,
-    notifications::{start_notification_server, AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS},
+    notifications::{AnonymousNotify, Notify, UpdateType, WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS},
     push::{
         push_cipher_update, push_folder_update, push_logout, push_send_update, push_user_update, register_push_device,
         unregister_push_device,
diff --git a/src/api/notifications.rs b/src/api/notifications.rs
index da2664cf..1f64b86e 100644
--- a/src/api/notifications.rs
+++ b/src/api/notifications.rs
@@ -1,23 +1,11 @@
-use std::{
-    net::{IpAddr, SocketAddr},
-    sync::Arc,
-    time::Duration,
-};
+use std::{net::IpAddr, sync::Arc, time::Duration};
 
 use chrono::{NaiveDateTime, Utc};
 use rmpv::Value;
-use rocket::{
-    futures::{SinkExt, StreamExt},
-    Route,
-};
-use tokio::{
-    net::{TcpListener, TcpStream},
-    sync::mpsc::Sender,
-};
-use tokio_tungstenite::{
-    accept_hdr_async,
-    tungstenite::{handshake, Message},
-};
+use rocket::{futures::StreamExt, Route};
+use tokio::sync::mpsc::Sender;
+
+use rocket_ws::{Message, WebSocket};
 
 use crate::{
     auth::{ClientIp, WsAccessTokenHeader},
@@ -30,7 +18,7 @@ use crate::{
 
 use once_cell::sync::Lazy;
 
-static WS_USERS: Lazy<Arc<WebSocketUsers>> = Lazy::new(|| {
+pub static WS_USERS: Lazy<Arc<WebSocketUsers>> = Lazy::new(|| {
     Arc::new(WebSocketUsers {
         map: Arc::new(dashmap::DashMap::new()),
     })
@@ -47,8 +35,15 @@ use super::{
     push_send_update, push_user_update,
 };
 
+static NOTIFICATIONS_DISABLED: Lazy<bool> = Lazy::new(|| !CONFIG.enable_websocket() && !CONFIG.push_enabled());
+
 pub fn routes() -> Vec<Route> {
-    routes![websockets_hub, anonymous_websockets_hub]
+    if CONFIG.enable_websocket() {
+        routes![websockets_hub, anonymous_websockets_hub]
+    } else {
+        info!("WebSocket are disabled, realtime sync functionality will not work!");
+        routes![]
+    }
 }
 
 #[derive(FromForm, Debug)]
@@ -108,7 +103,7 @@ impl Drop for WSAnonymousEntryMapGuard {
 
 #[get("/hub?<data..>")]
 fn websockets_hub<'r>(
-    ws: rocket_ws::WebSocket,
+    ws: WebSocket,
     data: WsAccessToken,
     ip: ClientIp,
     header_token: WsAccessTokenHeader,
@@ -192,11 +187,7 @@ fn websockets_hub<'r>(
 }
 
 #[get("/anonymous-hub?<token..>")]
-fn anonymous_websockets_hub<'r>(
-    ws: rocket_ws::WebSocket,
-    token: String,
-    ip: ClientIp,
-) -> Result<rocket_ws::Stream!['r], Error> {
+fn anonymous_websockets_hub<'r>(ws: WebSocket, token: String, ip: ClientIp) -> Result<rocket_ws::Stream!['r], Error> {
     let addr = ip.ip;
     info!("Accepting Anonymous Rocket WS connection from {addr}");
 
@@ -349,13 +340,19 @@ impl WebSocketUsers {
 
     // NOTE: The last modified date needs to be updated before calling these methods
     pub async fn send_user_update(&self, ut: UpdateType, user: &User) {
+        // Skip any processing if both WebSockets and Push are not active
+        if *NOTIFICATIONS_DISABLED {
+            return;
+        }
         let data = create_update(
             vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))],
             ut,
             None,
         );
 
-        self.send_update(&user.uuid, &data).await;
+        if CONFIG.enable_websocket() {
+            self.send_update(&user.uuid, &data).await;
+        }
 
         if CONFIG.push_enabled() {
             push_user_update(ut, user);
@@ -363,13 +360,19 @@ impl WebSocketUsers {
     }
 
     pub async fn send_logout(&self, user: &User, acting_device_uuid: Option<String>) {
+        // Skip any processing if both WebSockets and Push are not active
+        if *NOTIFICATIONS_DISABLED {
+            return;
+        }
         let data = create_update(
             vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))],
             UpdateType::LogOut,
             acting_device_uuid.clone(),
         );
 
-        self.send_update(&user.uuid, &data).await;
+        if CONFIG.enable_websocket() {
+            self.send_update(&user.uuid, &data).await;
+        }
 
         if CONFIG.push_enabled() {
             push_logout(user, acting_device_uuid);
@@ -383,6 +386,10 @@ impl WebSocketUsers {
         acting_device_uuid: &String,
         conn: &mut DbConn,
     ) {
+        // Skip any processing if both WebSockets and Push are not active
+        if *NOTIFICATIONS_DISABLED {
+            return;
+        }
         let data = create_update(
             vec![
                 ("Id".into(), folder.uuid.clone().into()),
@@ -393,7 +400,9 @@ impl WebSocketUsers {
             Some(acting_device_uuid.into()),
         );
 
-        self.send_update(&folder.user_uuid, &data).await;
+        if CONFIG.enable_websocket() {
+            self.send_update(&folder.user_uuid, &data).await;
+        }
 
         if CONFIG.push_enabled() {
             push_folder_update(ut, folder, acting_device_uuid, conn).await;
@@ -409,6 +418,10 @@ impl WebSocketUsers {
         collection_uuids: Option<Vec<String>>,
         conn: &mut DbConn,
     ) {
+        // Skip any processing if both WebSockets and Push are not active
+        if *NOTIFICATIONS_DISABLED {
+            return;
+        }
         let org_uuid = convert_option(cipher.organization_uuid.clone());
         // Depending if there are collections provided or not, we need to have different values for the following variables.
         // The user_uuid should be `null`, and the revision date should be set to now, else the clients won't sync the collection change.
@@ -434,8 +447,10 @@ impl WebSocketUsers {
             Some(acting_device_uuid.into()),
         );
 
-        for uuid in user_uuids {
-            self.send_update(uuid, &data).await;
+        if CONFIG.enable_websocket() {
+            for uuid in user_uuids {
+                self.send_update(uuid, &data).await;
+            }
         }
 
         if CONFIG.push_enabled() && user_uuids.len() == 1 {
@@ -451,6 +466,10 @@ impl WebSocketUsers {
         acting_device_uuid: &String,
         conn: &mut DbConn,
     ) {
+        // Skip any processing if both WebSockets and Push are not active
+        if *NOTIFICATIONS_DISABLED {
+            return;
+        }
         let user_uuid = convert_option(send.user_uuid.clone());
 
         let data = create_update(
@@ -463,8 +482,10 @@ impl WebSocketUsers {
             None,
         );
 
-        for uuid in user_uuids {
-            self.send_update(uuid, &data).await;
+        if CONFIG.enable_websocket() {
+            for uuid in user_uuids {
+                self.send_update(uuid, &data).await;
+            }
         }
         if CONFIG.push_enabled() && user_uuids.len() == 1 {
             push_send_update(ut, send, acting_device_uuid, conn).await;
@@ -478,12 +499,18 @@ impl WebSocketUsers {
         acting_device_uuid: &String,
         conn: &mut DbConn,
     ) {
+        // Skip any processing if both WebSockets and Push are not active
+        if *NOTIFICATIONS_DISABLED {
+            return;
+        }
         let data = create_update(
             vec![("Id".into(), auth_request_uuid.clone().into()), ("UserId".into(), user_uuid.clone().into())],
             UpdateType::AuthRequest,
             Some(acting_device_uuid.to_string()),
         );
-        self.send_update(user_uuid, &data).await;
+        if CONFIG.enable_websocket() {
+            self.send_update(user_uuid, &data).await;
+        }
 
         if CONFIG.push_enabled() {
             push_auth_request(user_uuid.to_string(), auth_request_uuid.to_string(), conn).await;
@@ -497,12 +524,18 @@ impl WebSocketUsers {
         approving_device_uuid: String,
         conn: &mut DbConn,
     ) {
+        // Skip any processing if both WebSockets and Push are not active
+        if *NOTIFICATIONS_DISABLED {
+            return;
+        }
         let data = create_update(
             vec![("Id".into(), auth_response_uuid.to_owned().into()), ("UserId".into(), user_uuid.clone().into())],
             UpdateType::AuthRequestResponse,
             approving_device_uuid.clone().into(),
         );
-        self.send_update(auth_response_uuid, &data).await;
+        if CONFIG.enable_websocket() {
+            self.send_update(auth_response_uuid, &data).await;
+        }
 
         if CONFIG.push_enabled() {
             push_auth_response(user_uuid.to_string(), auth_response_uuid.to_string(), approving_device_uuid, conn)
@@ -526,6 +559,9 @@ impl AnonymousWebSocketSubscriptions {
     }
 
     pub async fn send_auth_response(&self, user_uuid: &String, auth_response_uuid: &str) {
+        if !CONFIG.enable_websocket() {
+            return;
+        }
         let data = create_anonymous_update(
             vec![("Id".into(), auth_response_uuid.to_owned().into()), ("UserId".into(), user_uuid.clone().into())],
             UpdateType::AuthRequestResponse,
@@ -620,127 +656,3 @@ pub enum UpdateType {
 
 pub type Notify<'a> = &'a rocket::State<Arc<WebSocketUsers>>;
 pub type AnonymousNotify<'a> = &'a rocket::State<Arc<AnonymousWebSocketSubscriptions>>;
-
-pub fn start_notification_server() -> Arc<WebSocketUsers> {
-    let users = Arc::clone(&WS_USERS);
-    if CONFIG.websocket_enabled() {
-        let users2 = Arc::<WebSocketUsers>::clone(&users);
-        tokio::spawn(async move {
-            let addr = (CONFIG.websocket_address(), CONFIG.websocket_port());
-            info!("Starting WebSockets server on {}:{}", addr.0, addr.1);
-            let listener = TcpListener::bind(addr).await.expect("Can't listen on websocket port");
-
-            let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
-            CONFIG.set_ws_shutdown_handle(shutdown_tx);
-
-            loop {
-                tokio::select! {
-                    Ok((stream, addr)) = listener.accept() => {
-                        tokio::spawn(handle_connection(stream, Arc::<WebSocketUsers>::clone(&users2), addr));
-                    }
-
-                    _ = &mut shutdown_rx => {
-                        break;
-                    }
-                }
-            }
-
-            info!("Shutting down WebSockets server!")
-        });
-    }
-
-    users
-}
-
-async fn handle_connection(stream: TcpStream, users: Arc<WebSocketUsers>, addr: SocketAddr) -> Result<(), Error> {
-    let mut user_uuid: Option<String> = None;
-
-    info!("Accepting WS connection from {addr}");
-
-    // Accept connection, do initial handshake, validate auth token and get the user ID
-    use handshake::server::{Request, Response};
-    let mut stream = accept_hdr_async(stream, |req: &Request, res: Response| {
-        if let Some(token) = get_request_token(req) {
-            if let Ok(claims) = crate::auth::decode_login(&token) {
-                user_uuid = Some(claims.sub);
-                return Ok(res);
-            }
-        }
-        Err(Response::builder().status(401).body(None).unwrap())
-    })
-    .await?;
-
-    let user_uuid = user_uuid.expect("User UUID should be set after the handshake");
-
-    let (mut rx, guard) = {
-        // Add a channel to send messages to this client to the map
-        let entry_uuid = uuid::Uuid::new_v4();
-        let (tx, rx) = tokio::sync::mpsc::channel::<Message>(100);
-        users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));
-
-        // Once the guard goes out of scope, the connection will have been closed and the entry will be deleted from the map
-        (rx, WSEntryMapGuard::new(users, user_uuid, entry_uuid, addr.ip()))
-    };
-
-    let _guard = guard;
-    let mut interval = tokio::time::interval(Duration::from_secs(15));
-    loop {
-        tokio::select! {
-            res = stream.next() =>  {
-                match res {
-                    Some(Ok(message)) => {
-                        match message {
-                            // Respond to any pings
-                            Message::Ping(ping) => stream.send(Message::Pong(ping)).await?,
-                            Message::Pong(_) => {/* Ignored */},
-
-                            // We should receive an initial message with the protocol and version, and we will reply to it
-                            Message::Text(ref message) => {
-                                let msg = message.strip_suffix(RECORD_SEPARATOR as char).unwrap_or(message);
-
-                                if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
-                                    stream.send(Message::binary(INITIAL_RESPONSE)).await?;
-                                    continue;
-                                }
-                            }
-                            // Just echo anything else the client sends
-                            _ => stream.send(message).await?,
-                        }
-                    }
-                    _ => break,
-                }
-            }
-
-            res = rx.recv() => {
-                match res {
-                    Some(res) => stream.send(res).await?,
-                    None => break,
-                }
-            }
-
-            _ = interval.tick() => stream.send(Message::Ping(create_ping())).await?
-        }
-    }
-
-    Ok(())
-}
-
-fn get_request_token(req: &handshake::server::Request) -> Option<String> {
-    const ACCESS_TOKEN_KEY: &str = "access_token=";
-
-    if let Some(Ok(auth)) = req.headers().get("Authorization").map(|a| a.to_str()) {
-        if let Some(token_part) = auth.strip_prefix("Bearer ") {
-            return Some(token_part.to_owned());
-        }
-    }
-
-    if let Some(params) = req.uri().query() {
-        let params_iter = params.split('&').take(1);
-        for val in params_iter {
-            if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) {
-                return Some(stripped.to_owned());
-            }
-        }
-    }
-    None
-}
diff --git a/src/config.rs b/src/config.rs
index e174c66b..01f387ec 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -39,7 +39,6 @@ macro_rules! make_config {
 
         struct Inner {
             rocket_shutdown_handle: Option<rocket::Shutdown>,
-            ws_shutdown_handle: Option<tokio::sync::oneshot::Sender<()>>,
 
             templates: Handlebars<'static>,
             config: ConfigItems,
@@ -361,7 +360,7 @@ make_config! {
         /// Sends folder
         sends_folder:           String, false,  auto,   |c| format!("{}/{}", c.data_folder, "sends");
         /// Temp folder |> Used for storing temporary file uploads
-        tmp_folder:           String, false,  auto,   |c| format!("{}/{}", c.data_folder, "tmp");
+        tmp_folder:             String, false,  auto,   |c| format!("{}/{}", c.data_folder, "tmp");
         /// Templates folder
         templates_folder:       String, false,  auto,   |c| format!("{}/{}", c.data_folder, "templates");
         /// Session JWT key
@@ -371,11 +370,7 @@ make_config! {
     },
     ws {
         /// Enable websocket notifications
-        websocket_enabled:      bool,   false,  def,    false;
-        /// Websocket address
-        websocket_address:      String, false,  def,    "0.0.0.0".to_string();
-        /// Websocket port
-        websocket_port:         u16,    false,  def,    3012;
+        enable_websocket:       bool,   false,  def,    true;
     },
     push {
         /// Enable push notifications
@@ -1071,7 +1066,6 @@ impl Config {
         Ok(Config {
             inner: RwLock::new(Inner {
                 rocket_shutdown_handle: None,
-                ws_shutdown_handle: None,
                 templates: load_templates(&config.templates_folder),
                 config,
                 _env,
@@ -1237,16 +1231,8 @@ impl Config {
         self.inner.write().unwrap().rocket_shutdown_handle = Some(handle);
     }
 
-    pub fn set_ws_shutdown_handle(&self, handle: tokio::sync::oneshot::Sender<()>) {
-        self.inner.write().unwrap().ws_shutdown_handle = Some(handle);
-    }
-
     pub fn shutdown(&self) {
         if let Ok(mut c) = self.inner.write() {
-            if let Some(handle) = c.ws_shutdown_handle.take() {
-                handle.send(()).ok();
-            }
-
             if let Some(handle) = c.rocket_shutdown_handle.take() {
                 handle.notify();
             }
diff --git a/src/error.rs b/src/error.rs
index f0969bff..784aad6a 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -52,7 +52,6 @@ use rocket::error::Error as RocketErr;
 use serde_json::{Error as SerdeErr, Value};
 use std::io::Error as IoErr;
 use std::time::SystemTimeError as TimeErr;
-use tokio_tungstenite::tungstenite::Error as TungstError;
 use webauthn_rs::error::WebauthnError as WebauthnErr;
 use yubico::yubicoerror::YubicoError as YubiErr;
 
@@ -91,7 +90,6 @@ make_error! {
 
     DieselCon(DieselConErr): _has_source, _api_error,
     Webauthn(WebauthnErr):   _has_source, _api_error,
-    WebSocket(TungstError):  _has_source, _api_error,
 }
 
 impl std::fmt::Debug for Error {
diff --git a/src/main.rs b/src/main.rs
index 53b72606..285dc33a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -52,7 +52,7 @@ mod ratelimit;
 mod util;
 
 use crate::api::purge_auth_requests;
-use crate::api::WS_ANONYMOUS_SUBSCRIPTIONS;
+use crate::api::{WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS};
 pub use config::CONFIG;
 pub use error::{Error, MapResult};
 use rocket::data::{Limits, ToByteUnit};
@@ -497,7 +497,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
         .register([basepath, "/api"].concat(), api::core_catchers())
         .register([basepath, "/admin"].concat(), api::admin_catchers())
         .manage(pool)
-        .manage(api::start_notification_server())
+        .manage(Arc::clone(&WS_USERS))
         .manage(Arc::clone(&WS_ANONYMOUS_SUBSCRIPTIONS))
         .attach(util::AppHeaders())
         .attach(util::Cors())