From cb889193c7577393ca526215ae6ce48a3c3b85a5 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Fri, 10 May 2024 15:12:33 +0200 Subject: [PATCH] wip --- .gitignore | 2 +- crates/lavina-core/src/clustering.rs | 84 ++----------------- .../lavina-core/src/clustering/broadcast.rs | 3 + crates/lavina-core/src/clustering/room.rs | 59 +++++++++++++ crates/lavina-core/src/player.rs | 49 +++++++++-- crates/lavina-core/src/repo/room.rs | 15 ++++ crates/lavina-core/src/repo/user.rs | 15 ++++ crates/lavina-core/src/room.rs | 2 +- src/http.rs | 74 +--------------- src/http/clustering.rs | 78 +++++++++++++++++ 10 files changed, 225 insertions(+), 156 deletions(-) create mode 100644 crates/lavina-core/src/clustering/room.rs create mode 100644 src/http/clustering.rs diff --git a/.gitignore b/.gitignore index 75d301e..2ee75f0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ /target -/db.sqlite +*.sqlite .idea/ .DS_Store diff --git a/crates/lavina-core/src/clustering.rs b/crates/lavina-core/src/clustering.rs index a5468cb..20ad30d 100644 --- a/crates/lavina-core/src/clustering.rs +++ b/crates/lavina-core/src/clustering.rs @@ -7,6 +7,7 @@ use std::net::SocketAddr; use std::sync::Arc; pub mod broadcast; +pub mod room; type Addresses = Vec; @@ -33,35 +34,6 @@ pub struct LavinaClient { client: ClientWithMiddleware, } -#[derive(Serialize, Deserialize, Debug)] -pub struct SendMessageReq<'a> { - pub room_id: &'a str, - pub player_id: &'a str, - pub message: &'a str, - pub created_at: &'a str, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct BroadcastMessageReq<'a> { - pub room_id: &'a str, - pub author_id: &'a str, - pub message: &'a str, - pub created_at: &'a str, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct SetRoomTopicReq<'a> { - pub room_id: &'a str, - pub player_id: &'a str, - pub topic: &'a str, -} - -pub mod paths { - pub const ADD_MESSAGE: &'static str = "/cluster/rooms/add_message"; - pub const BROADCAST_MESSAGE: &'static str = "/cluster/rooms/broadcast_message"; - pub const SET_TOPIC: &'static str = "/cluster/rooms/set_topic"; -} - impl LavinaClient { pub fn new(addresses: Addresses) -> Self { let client = ClientBuilder::new(Client::new()).with(TracingMiddleware::::new()).build(); @@ -71,59 +43,13 @@ impl LavinaClient { } } - #[tracing::instrument(skip(self, req), name = "LavinaClient::send_room_message")] - pub async fn send_room_message(&self, node_id: u32, req: SendMessageReq<'_>) -> Result<()> { - tracing::info!("Sending a message to a room on a remote node"); + async fn send_request(&self, node_id: u32, path: &str, req: impl Serialize) -> Result<()> { let Some(address) = self.addresses.get(node_id as usize) else { - tracing::error!("Failed"); return Err(anyhow!("Unknown node")); }; - match self.client.post(format!("http://{}{}", address, paths::BROADCAST_MESSAGE)).json(&req).send().await { - Ok(_) => { - tracing::info!("Message sent"); - Ok(()) - } - Err(e) => { - tracing::error!("Failed to send message: {e:?}"); - Err(e.into()) - } - } - } - - #[tracing::instrument(skip(self, req), name = "LavinaClient::broadcast_room_message")] - pub async fn broadcast_room_message(&self, node_id: u32, req: BroadcastMessageReq<'_>) -> Result<()> { - tracing::info!("Broadcasting a message to a room on a remote node"); - let Some(address) = self.addresses.get(node_id as usize) else { - tracing::error!("Failed"); - return Err(anyhow!("Unknown node")); - }; - match self.client.post(format!("http://{}{}", address, paths::BROADCAST_MESSAGE)).json(&req).send().await { - Ok(_) => { - tracing::info!("Message broadcasted"); - Ok(()) - } - Err(e) => { - tracing::error!("Failed to broadcast message: {e:?}"); - Err(e.into()) - } - } - } - - pub async fn set_room_topic(&self, node_id: u32, req: SetRoomTopicReq<'_>) -> Result<()> { - tracing::info!("Setting the topic of a room on a remote node"); - let Some(address) = self.addresses.get(node_id as usize) else { - tracing::error!("Failed"); - return Err(anyhow!("Unknown node")); - }; - match self.client.post(format!("http://{}{}", address, paths::SET_TOPIC)).json(&req).send().await { - Ok(_) => { - tracing::info!("Room topic set"); - Ok(()) - } - Err(e) => { - tracing::error!("Failed to set room topic: {e:?}"); - Err(e.into()) - } + match self.client.post(format!("http://{}{}", address, path)).json(&req).send().await { + Ok(_) => Ok(()), + Err(e) => Err(e.into()), } } } diff --git a/crates/lavina-core/src/clustering/broadcast.rs b/crates/lavina-core/src/clustering/broadcast.rs index 456f820..f42754a 100644 --- a/crates/lavina-core/src/clustering/broadcast.rs +++ b/crates/lavina-core/src/clustering/broadcast.rs @@ -45,6 +45,9 @@ impl Broadcasting { created_at: created_at.clone(), }; for i in subscribers { + if i == &author_id { + continue; + } let Some(player) = players.get_player(i).await else { continue; }; diff --git a/crates/lavina-core/src/clustering/room.rs b/crates/lavina-core/src/clustering/room.rs new file mode 100644 index 0000000..dff5fa3 --- /dev/null +++ b/crates/lavina-core/src/clustering/room.rs @@ -0,0 +1,59 @@ +use serde::{Deserialize, Serialize}; + +use crate::clustering::LavinaClient; + +pub mod paths { + pub const JOIN: &'static str = "/cluster/rooms/join"; + pub const LEAVE: &'static str = "/cluster/rooms/leave"; + pub const ADD_MESSAGE: &'static str = "/cluster/rooms/add_message"; + pub const SET_TOPIC: &'static str = "/cluster/rooms/set_topic"; +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct JoinRoomReq<'a> { + pub room_id: &'a str, + pub player_id: &'a str, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct LeaveRoomReq<'a> { + pub room_id: &'a str, + pub player_id: &'a str, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SendMessageReq<'a> { + pub room_id: &'a str, + pub player_id: &'a str, + pub message: &'a str, + pub created_at: &'a str, +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct SetRoomTopicReq<'a> { + pub room_id: &'a str, + pub player_id: &'a str, + pub topic: &'a str, +} + +impl LavinaClient { + #[tracing::instrument(skip(self, req), name = "LavinaClient::join_room")] + pub async fn join_room(&self, node_id: u32, req: JoinRoomReq<'_>) -> anyhow::Result<()> { + self.send_request(node_id, paths::JOIN, req).await + } + + #[tracing::instrument(skip(self, req), name = "LavinaClient::leave_room")] + pub async fn leave_room(&self, node_id: u32, req: LeaveRoomReq<'_>) -> anyhow::Result<()> { + self.send_request(node_id, paths::LEAVE, req).await + } + + #[tracing::instrument(skip(self, req), name = "LavinaClient::send_room_message")] + pub async fn send_room_message(&self, node_id: u32, req: SendMessageReq<'_>) -> anyhow::Result<()> { + self.send_request(node_id, paths::ADD_MESSAGE, req).await + } + + #[tracing::instrument(skip(self, req), name = "LavinaClient::set_room_topic")] + pub async fn set_room_topic(&self, node_id: u32, req: SetRoomTopicReq<'_>) -> anyhow::Result<()> { + self.send_request(node_id, paths::SET_TOPIC, req).await + } +} diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 07fdfd5..9387991 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -19,7 +19,8 @@ use tokio::sync::RwLock; use tracing::{Instrument, Span}; use crate::clustering::broadcast::Broadcasting; -use crate::clustering::{ClusterMetadata, LavinaClient, SendMessageReq, SetRoomTopicReq}; +use crate::clustering::room::*; +use crate::clustering::{ClusterMetadata, LavinaClient}; use crate::dialog::DialogRegistry; use crate::prelude::*; use crate::repo::Storage; @@ -336,6 +337,7 @@ impl PlayerRegistry { } else { let (handle, fiber) = Player::launch( id.clone(), + self.clone(), inner.room_registry.clone(), inner.dialogs.clone(), inner.cluster_metadata.clone(), @@ -397,6 +399,7 @@ struct Player { banned_from: HashSet, rx: Receiver<(ActorCommand, Span)>, handle: PlayerHandle, + players: PlayerRegistry, rooms: RoomRegistry, dialogs: DialogRegistry, storage: Storage, @@ -407,6 +410,7 @@ struct Player { impl Player { async fn launch( player_id: PlayerId, + players: PlayerRegistry, rooms: RoomRegistry, dialogs: DialogRegistry, cluster_metadata: Arc, @@ -429,6 +433,7 @@ impl Player { banned_from: HashSet::new(), rx, handle, + players, rooms, dialogs, storage, @@ -582,7 +587,19 @@ impl Player { } if let Some(remote_node) = self.room_location(&room_id) { - todo!() + let req = JoinRoomReq { + room_id: room_id.as_inner(), + player_id: self.player_id.as_inner(), + }; + self.cluster_client.join_room(remote_node, req).await.unwrap(); + let room_storage_id = self.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap(); + self.storage.add_room_member(room_storage_id, self.storage_id).await.unwrap(); + self.my_rooms.insert(room_id.clone(), RoomRef::Remote { node_id: remote_node }); + JoinResult::Success(RoomInfo { + id: room_id, + topic: "unknown".into(), + members: vec![], + }) } else { let room = match self.rooms.get_or_create_room(room_id.clone()).await { Ok(room) => room, @@ -608,9 +625,22 @@ impl Player { async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) { let room = self.my_rooms.remove(&room_id); if let Some(room) = room { - panic!(); - // room.unsubscribe(&self.player_id).await; - // room.remove_member(&self.player_id, self.storage_id).await; + match room { + RoomRef::Local(room) => { + room.unsubscribe(&self.player_id).await; + room.remove_member(&self.player_id, self.storage_id).await; + let room_storage_id = + self.storage.create_or_retrieve_room_id_by_name(room_id.as_inner()).await.unwrap(); + self.storage.remove_room_member(room_storage_id, self.storage_id).await.unwrap(); + } + RoomRef::Remote { node_id } => { + let req = LeaveRoomReq { + room_id: room_id.as_inner(), + player_id: self.player_id.as_inner(), + }; + self.cluster_client.leave_room(node_id, req).await.unwrap(); + } + } } let update = Updates::RoomLeft { room_id, @@ -643,6 +673,15 @@ impl Player { created_at: &*created_at.to_rfc3339(), }; self.cluster_client.send_room_message(*node_id, req).await.unwrap(); + self.broadcasting + .broadcast( + &self.players, + room_id.clone(), + self.player_id.clone(), + body.clone(), + created_at.clone(), + ) + .await; } } let update = Updates::NewMessage { diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs index 38de47d..d831b74 100644 --- a/crates/lavina-core/src/repo/room.rs +++ b/crates/lavina-core/src/repo/room.rs @@ -48,4 +48,19 @@ impl Storage { Ok(()) } + + pub async fn create_or_retrieve_room_id_by_name(&self, name: &str) -> Result { + let mut executor = self.conn.lock().await; + let res: (u32,) = sqlx::query_as( + "insert into rooms(name, topic) + values (?, '') + on conflict(name) do nothing + returning id;", + ) + .bind(name) + .fetch_one(&mut *executor) + .await?; + + Ok(res.0) + } } diff --git a/crates/lavina-core/src/repo/user.rs b/crates/lavina-core/src/repo/user.rs index d836b8f..a27c245 100644 --- a/crates/lavina-core/src/repo/user.rs +++ b/crates/lavina-core/src/repo/user.rs @@ -14,6 +14,21 @@ impl Storage { Ok(res.map(|(id,)| id)) } + pub async fn create_or_retrieve_user_id_by_name(&self, name: &str) -> Result { + let mut executor = self.conn.lock().await; + let res: (u32,) = sqlx::query_as( + "insert into users(name) + values (?) + on conflict(name) do update set name = excluded.name + returning id;", + ) + .bind(name) + .fetch_one(&mut *executor) + .await?; + + Ok(res.0) + } + pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result> { let mut executor = self.conn.lock().await; let res: Vec<(String,)> = sqlx::query_as( diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 17a463b..599b52c 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -60,7 +60,7 @@ impl RoomRegistry { } #[tracing::instrument(skip(self), name = "RoomRegistry::get_or_create_room")] - pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result { + pub async fn get_or_create_room(&self, room_id: RoomId) -> Result { let mut inner = self.0.write().await; if let Some(room_handle) = inner.get_or_load_room(&room_id).await? { Ok(room_handle.clone()) diff --git a/src/http.rs b/src/http.rs index 633ae14..4d36181 100644 --- a/src/http.rs +++ b/src/http.rs @@ -13,16 +13,16 @@ use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use lavina_core::auth::UpdatePasswordResult; -use lavina_core::clustering::{BroadcastMessageReq, SendMessageReq}; use lavina_core::player::{PlayerId, PlayerRegistry, SendMessageResult}; use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::{RoomId, RoomRegistry}; use lavina_core::terminator::Terminator; -use lavina_core::{clustering, LavinaCore}; - +use lavina_core::LavinaCore; use mgmt_api::*; +mod clustering; + type HttpResult = std::result::Result; #[derive(Deserialize, Debug)] @@ -91,11 +91,7 @@ async fn route( (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, core).await.or5xx(), (&Method::POST, rooms::paths::SEND_MESSAGE) => endpoint_send_room_message(request, core).await.or5xx(), (&Method::POST, rooms::paths::SET_TOPIC) => endpoint_set_room_topic(request, core).await.or5xx(), - (&Method::POST, clustering::paths::ADD_MESSAGE) => endpoint_cluster_add_message(request, core).await.or5xx(), - (&Method::POST, clustering::paths::BROADCAST_MESSAGE) => { - endpoint_cluster_broadcast_message(request, core).await.or5xx() - } - _ => endpoint_not_found(), + _ => clustering::route(core, storage, request).await.unwrap_or_else(endpoint_not_found), }; Ok(res) } @@ -211,68 +207,6 @@ async fn endpoint_set_room_topic( Ok(empty_204_request()) } -#[tracing::instrument(skip_all, name = "endpoint_cluster_add_message")] -async fn endpoint_cluster_add_message( - request: Request, - core: &LavinaCore, -) -> Result>> { - let str = request.collect().await?.to_bytes(); - let Ok(req) = serde_json::from_slice::(&str[..]) else { - return Ok(malformed_request()); - }; - tracing::info!("Incoming request: {:?}", &req); - let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else { - dbg!(&req.created_at); - return Ok(malformed_request()); - }; - let Ok(room_id) = RoomId::from(req.room_id) else { - dbg!(&req.room_id); - return Ok(room_not_found()); - }; - let Ok(player_id) = PlayerId::from(req.player_id) else { - dbg!(&req.player_id); - return Ok(player_not_found()); - }; - let Some(room_handle) = core.rooms.get_room(&room_id).await else { - dbg!(&room_id); - return Ok(room_not_found()); - }; - room_handle.send_message(&player_id, req.message.into(), created_at.to_utc()).await; - Ok(empty_204_request()) -} - -#[tracing::instrument(skip_all, name = "endpoint_cluster_broadcast_message")] -async fn endpoint_cluster_broadcast_message( - request: Request, - core: &LavinaCore, -) -> Result>> { - let str = request.collect().await?.to_bytes(); - let Ok(req) = serde_json::from_slice::(&str[..]) else { - return Ok(malformed_request()); - }; - let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else { - return Ok(malformed_request()); - }; - let Ok(room_id) = RoomId::from(req.room_id) else { - return Ok(room_not_found()); - }; - let Ok(author_id) = PlayerId::from(req.author_id) else { - return Ok(player_not_found()); - }; - let broadcasting = core.broadcasting.0.lock().await; - broadcasting - .broadcast( - &core.players, - room_id, - author_id, - req.message.into(), - created_at.to_utc(), - ) - .await; - drop(broadcasting); - Ok(empty_204_request()) -} - fn endpoint_not_found() -> Response> { let payload = ErrorResponse { code: errors::INVALID_PATH, diff --git a/src/http/clustering.rs b/src/http/clustering.rs new file mode 100644 index 0000000..df2605c --- /dev/null +++ b/src/http/clustering.rs @@ -0,0 +1,78 @@ +use http_body_util::{BodyExt, Full}; +use hyper::body::Bytes; +use hyper::{Method, Request, Response}; + +use super::Or5xx; +use crate::http::{empty_204_request, malformed_request, player_not_found, room_not_found}; +use lavina_core::clustering::room::{paths, JoinRoomReq, SendMessageReq}; +use lavina_core::player::PlayerId; +use lavina_core::repo::Storage; +use lavina_core::room::RoomId; +use lavina_core::LavinaCore; + +pub async fn route( + core: &LavinaCore, + storage: &Storage, + request: Request, +) -> Option>> { + match (request.method(), request.uri().path()) { + (&Method::POST, paths::JOIN) => Some(endpoint_cluster_join_room(request, core, storage).await.or5xx()), + (&Method::POST, paths::ADD_MESSAGE) => Some(endpoint_cluster_add_message(request, core).await.or5xx()), + _ => None, + } +} + +#[tracing::instrument(skip_all, name = "endpoint_cluster_join_room")] +async fn endpoint_cluster_join_room( + request: Request, + core: &LavinaCore, + storage: &Storage, +) -> lavina_core::prelude::Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(req) = serde_json::from_slice::(&str[..]) else { + return Ok(malformed_request()); + }; + tracing::info!("Incoming request: {:?}", &req); + let Ok(room_id) = RoomId::from(req.room_id) else { + dbg!(&req.room_id); + return Ok(room_not_found()); + }; + let Ok(player_id) = PlayerId::from(req.player_id) else { + dbg!(&req.player_id); + return Ok(player_not_found()); + }; + let room_handle = core.rooms.get_or_create_room(room_id).await.unwrap(); + let storage_id = storage.create_or_retrieve_user_id_by_name(req.player_id).await?; + room_handle.add_member(&player_id, storage_id).await; + Ok(empty_204_request()) +} + +#[tracing::instrument(skip_all, name = "endpoint_cluster_add_message")] +async fn endpoint_cluster_add_message( + request: Request, + core: &LavinaCore, +) -> lavina_core::prelude::Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(req) = serde_json::from_slice::(&str[..]) else { + return Ok(malformed_request()); + }; + tracing::info!("Incoming request: {:?}", &req); + let Ok(created_at) = chrono::DateTime::parse_from_rfc3339(req.created_at) else { + dbg!(&req.created_at); + return Ok(malformed_request()); + }; + let Ok(room_id) = RoomId::from(req.room_id) else { + dbg!(&req.room_id); + return Ok(room_not_found()); + }; + let Ok(player_id) = PlayerId::from(req.player_id) else { + dbg!(&req.player_id); + return Ok(player_not_found()); + }; + let Some(room_handle) = core.rooms.get_room(&room_id).await else { + dbg!(&room_id); + return Ok(room_not_found()); + }; + room_handle.send_message(&player_id, req.message.into(), created_at.to_utc()).await; + Ok(empty_204_request()) +}