From a325c7307cc75ee39a597fce80f01cf1371e23dc Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sat, 6 Apr 2024 22:34:11 +0000 Subject: [PATCH 01/37] irc: improve tests and remove tail space in chan member list --- crates/projection-irc/tests/lib.rs | 204 ++++++++++++++++++++++------- crates/proto-irc/src/server.rs | 7 +- 2 files changed, 162 insertions(+), 49 deletions(-) diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index c3efd2a..240a067 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -48,6 +48,31 @@ impl<'a> TestScope<'a> { Ok(()) } + async fn expect_that(&mut self, validate: impl FnOnce(&str) -> bool) -> Result<()> { + let len = tokio::time::timeout(self.timeout, read_irc_message(&mut self.reader, &mut self.buffer)).await??; + let msg = std::str::from_utf8(&self.buffer[..len - 2])?; + if !validate(msg) { + return Err(anyhow!("unexpected message: {:?}", msg)); + } + self.buffer.clear(); + Ok(()) + } + + async fn expect_server_introduction(&mut self, nick: &str) -> Result<()> { + self.expect(&format!(":testserver 001 {nick} :Welcome to testserver Server")).await?; + self.expect(&format!(":testserver 002 {nick} :Welcome to testserver Server")).await?; + self.expect(&format!(":testserver 003 {nick} :Welcome to testserver Server")).await?; + self.expect(&format!( + ":testserver 004 {nick} testserver {APP_VERSION} r CFILPQbcefgijklmnopqrstvz" + )) + .await?; + self.expect(&format!( + ":testserver 005 {nick} CHANTYPES=# :are supported by this server" + )) + .await?; + Ok(()) + } + async fn expect_eof(&mut self) -> Result<()> { let mut buf = [0; 1]; let len = tokio::time::timeout(self.timeout, self.reader.read(&mut buf)).await??; @@ -113,18 +138,7 @@ async fn scenario_basic() -> Result<()> { s.send("PASS password").await?; s.send("NICK tester").await?; s.send("USER UserName 0 * :Real Name").await?; - s.expect(":testserver 001 tester :Welcome to testserver Server").await?; - s.expect(":testserver 002 tester :Welcome to testserver Server").await?; - s.expect(":testserver 003 tester :Welcome to testserver Server").await?; - s.expect( - format!( - ":testserver 004 tester testserver {} r CFILPQbcefgijklmnopqrstvz", - &APP_VERSION - ) - .as_str(), - ) - .await?; - s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; + s.expect_server_introduction("tester").await?; s.expect_nothing().await?; s.send("QUIT :Leaving").await?; s.expect(":testserver ERROR :Leaving the server").await?; @@ -138,6 +152,133 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_force_join_msg() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream1 = TcpStream::connect(server.server.addr).await?; + let mut s1 = TestScope::new(&mut stream1); + let mut stream2 = TcpStream::connect(server.server.addr).await?; + let mut s2 = TestScope::new(&mut stream2); + + s1.send("PASS password").await?; + s1.send("NICK tester").await?; + s1.send("USER UserName 0 * :Real Name").await?; + s1.expect_server_introduction("tester").await?; + s1.expect_nothing().await?; + + s2.send("PASS password").await?; + s2.send("NICK tester").await?; + s2.send("USER UserName 0 * :Real Name").await?; + s2.expect_server_introduction("tester").await?; + s2.expect_nothing().await?; + + // We join the channel from the first connection + s1.send("JOIN #test").await?; + s1.expect(":tester JOIN #test").await?; + s1.expect(":testserver 332 tester #test :New room").await?; + s1.expect(":testserver 353 tester = #test :tester").await?; + s1.expect(":testserver 366 tester #test :End of /NAMES list").await?; + + // And the second connection should receive the JOIN message (forced JOIN) + s2.expect(":tester JOIN #test").await?; + s2.expect(":testserver 332 tester #test :New room").await?; + s2.expect(":testserver 353 tester = #test :tester").await?; + s2.expect(":testserver 366 tester #test :End of /NAMES list").await?; + + // We send a message to the channel from the second connection + s2.send("PRIVMSG #test :Hello").await?; + // We should not receive an acknowledgement from the server + s2.expect_nothing().await?; + // But we should receive this message from the first connection + s1.expect(":tester PRIVMSG #test :Hello").await?; + + s1.send("QUIT :Leaving").await?; + s1.expect(":testserver ERROR :Leaving the server").await?; + s1.expect_eof().await?; + + // Closing a connection does not kick you from the channel on a different connection + s2.expect_nothing().await?; + + s2.send("QUIT :Leaving").await?; + s2.expect(":testserver ERROR :Leaving the server").await?; + s2.expect_eof().await?; + + stream1.shutdown().await?; + stream2.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} + +#[tokio::test] +async fn scenario_two_users() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester1").await?; + server.storage.set_password("tester1", "password").await?; + server.storage.create_user("tester2").await?; + server.storage.set_password("tester2", "password").await?; + + let mut stream1 = TcpStream::connect(server.server.addr).await?; + let mut s1 = TestScope::new(&mut stream1); + let mut stream2 = TcpStream::connect(server.server.addr).await?; + let mut s2 = TestScope::new(&mut stream2); + + s1.send("PASS password").await?; + s1.send("NICK tester1").await?; + s1.send("USER UserName 0 * :Real Name").await?; + s1.expect_server_introduction("tester1").await?; + s1.expect_nothing().await?; + + s2.send("PASS password").await?; + s2.send("NICK tester2").await?; + s2.send("USER UserName 0 * :Real Name").await?; + s2.expect_server_introduction("tester2").await?; + s2.expect_nothing().await?; + + // Join the channel from the first user + s1.send("JOIN #test").await?; + s1.expect(":tester1 JOIN #test").await?; + s1.expect(":testserver 332 tester1 #test :New room").await?; + s1.expect(":testserver 353 tester1 = #test :tester1").await?; + s1.expect(":testserver 366 tester1 #test :End of /NAMES list").await?; + // Then join the channel from the second user + s2.send("JOIN #test").await?; + s2.expect(":tester2 JOIN #test").await?; + s2.expect(":testserver 332 tester2 #test :New room").await?; + s2.expect_that(|msg| { + msg == ":testserver 353 tester2 = #test :tester1 tester2" + || msg == ":testserver 353 tester2 = #test :tester2 tester1" + }) + .await?; + s2.expect(":testserver 366 tester2 #test :End of /NAMES list").await?; + // The first user should receive the JOIN message from the second user + s1.expect(":tester2 JOIN #test").await?; + s1.expect_nothing().await?; + s2.expect_nothing().await?; + // Send a message from the second user + s2.send("PRIVMSG #test :Hello").await?; + // The first user should receive the message + s1.expect(":tester2 PRIVMSG #test :Hello").await?; + // Leave the channel from the first user + // TODO implement irc PART command + // s1.send("PART #test").await?; + // s1.expect(":tester1 PART #test").await?; + // The second user should receive the PART message + // s2.expect(":tester1 PART #test").await?; + Ok(()) +} + /* IRC SASL doc: https://ircv3.net/specs/extensions/sasl-3.1.html AUTHENTICATE doc: https://modern.ircdocs.horse/#authenticate-message @@ -168,18 +309,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> { s.send("CAP END").await?; - s.expect(":testserver 001 tester :Welcome to testserver Server").await?; - s.expect(":testserver 002 tester :Welcome to testserver Server").await?; - s.expect(":testserver 003 tester :Welcome to testserver Server").await?; - s.expect( - format!( - ":testserver 004 tester testserver {} r CFILPQbcefgijklmnopqrstvz", - &APP_VERSION - ) - .as_str(), - ) - .await?; - s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; + s.expect_server_introduction("tester").await?; s.expect_nothing().await?; s.send("QUIT :Leaving").await?; s.expect(":testserver ERROR :Leaving the server").await?; @@ -217,18 +347,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> { s.send("CAP END").await?; - s.expect(":testserver 001 tester :Welcome to testserver Server").await?; - s.expect(":testserver 002 tester :Welcome to testserver Server").await?; - s.expect(":testserver 003 tester :Welcome to testserver Server").await?; - s.expect( - format!( - ":testserver 004 tester testserver {} r CFILPQbcefgijklmnopqrstvz", - &APP_VERSION - ) - .as_str(), - ) - .await?; - s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; + s.expect_server_introduction("tester").await?; s.expect_nothing().await?; s.send("QUIT :Leaving").await?; s.expect(":testserver ERROR :Leaving the server").await?; @@ -272,18 +391,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> { s.send("CAP END").await?; - s.expect(":testserver 001 tester :Welcome to testserver Server").await?; - s.expect(":testserver 002 tester :Welcome to testserver Server").await?; - s.expect(":testserver 003 tester :Welcome to testserver Server").await?; - s.expect( - format!( - ":testserver 004 tester testserver {} r CFILPQbcefgijklmnopqrstvz", - &APP_VERSION - ) - .as_str(), - ) - .await?; - s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; + s.expect_server_introduction("tester").await?; s.expect_nothing().await?; s.send("QUIT :Leaving").await?; s.expect(":testserver ERROR :Leaving the server").await?; diff --git a/crates/proto-irc/src/server.rs b/crates/proto-irc/src/server.rs index 6e1bd66..c751e23 100644 --- a/crates/proto-irc/src/server.rs +++ b/crates/proto-irc/src/server.rs @@ -317,10 +317,15 @@ impl ServerMessageBody { writer.write_all(b" = ").await?; chan.write_async(writer).await?; writer.write_all(b" :").await?; - for member in members { + { + let member = &members.head; writer.write_all(member.prefix.to_string().as_bytes()).await?; writer.write_all(member.nick.as_bytes()).await?; + } + for member in &members.tail { writer.write_all(b" ").await?; + writer.write_all(member.prefix.to_string().as_bytes()).await?; + writer.write_all(member.nick.as_bytes()).await?; } } ServerMessageBody::N366NamesReplyEnd { client, chan } => { From fd437df67ee5b6ae01dcb61c033ad581402685cd Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sat, 6 Apr 2024 09:23:12 +0000 Subject: [PATCH 02/37] xmpp: buffer data outgoing over a TLS-stream --- crates/projection-xmpp/src/lib.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index a2a0a5b..3b71b84 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -185,7 +185,7 @@ async fn handle_socket( let (a, b) = tokio::io::split(new_stream); let mut xml_reader = NsReader::from_reader(BufReader::new(a)); - let mut xml_writer = Writer::new(b); + let mut xml_writer = Writer::new(BufWriter::new(b)); pin!(termination); select! { @@ -216,7 +216,7 @@ async fn handle_socket( } let a = xml_reader.into_inner().into_inner(); - let b = xml_writer.into_inner(); + let b = xml_writer.into_inner().into_inner(); a.unsplit(b).shutdown().await?; Ok(()) } @@ -284,6 +284,7 @@ async fn socket_auth( let auth: proto_xmpp::sasl::Auth = proto_xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?; proto_xmpp::sasl::Success.write_xml(xml_writer).await?; + xml_writer.get_mut().flush().await?; match AuthBody::from_str(&auth.body) { Ok(logopass) => { From ab61e975bf1912e3c012d90f882e5571e7dab963 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sat, 6 Apr 2024 15:28:46 +0000 Subject: [PATCH 03/37] xmpp: send correct errors on unknown iqs --- crates/projection-xmpp/src/iq.rs | 6 +++-- crates/proto-xmpp/src/client.rs | 38 ++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/crates/projection-xmpp/src/iq.rs b/crates/projection-xmpp/src/iq.rs index 3ad0ae5..9c50ecc 100644 --- a/crates/projection-xmpp/src/iq.rs +++ b/crates/projection-xmpp/src/iq.rs @@ -4,7 +4,7 @@ use quick_xml::events::Event; use lavina_core::room::RoomRegistry; use proto_xmpp::bind::{BindResponse, Jid, Name, Resource, Server}; -use proto_xmpp::client::{Iq, IqType}; +use proto_xmpp::client::{Iq, IqError, IqErrorType, IqType}; use proto_xmpp::disco::{Feature, Identity, InfoQuery, Item, ItemQuery}; use proto_xmpp::roster::RosterQuery; use proto_xmpp::session::Session; @@ -79,7 +79,9 @@ impl<'a> XmppConnection<'a> { id: iq.id, to: None, r#type: IqType::Error, - body: (), + body: IqError { + r#type: IqErrorType::Cancel, + }, }; req.serialize(output); } diff --git a/crates/proto-xmpp/src/client.rs b/crates/proto-xmpp/src/client.rs index 8276d7c..4943283 100644 --- a/crates/proto-xmpp/src/client.rs +++ b/crates/proto-xmpp/src/client.rs @@ -255,6 +255,44 @@ impl MessageType { } } +/// Error response to an IQ request. +/// +/// https://xmpp.org/rfcs/rfc6120.html#stanzas-error +pub struct IqError { + pub r#type: IqErrorType, +} + +pub enum IqErrorType { + /// Retry after providing credentials + Auth, + /// Do not retry (the error cannot be remedied) + Cancel, + /// Proceed (the condition was only a warning) + Continue, + /// Retry after changing the data sent + Modify, + /// Retry after waiting (the error is temporary) + Wait, +} +impl IqErrorType { + pub fn as_str(&self) -> &'static str { + match self { + IqErrorType::Auth => "auth", + IqErrorType::Cancel => "cancel", + IqErrorType::Continue => "continue", + IqErrorType::Modify => "modify", + IqErrorType::Wait => "wait", + } + } +} + +impl ToXml for IqError { + fn serialize(&self, events: &mut Vec>) { + let bytes = BytesStart::new(format!(r#"error xmlns="{}" type="{}""#, XMLNS, self.r#type.as_str())); + events.push(Event::Empty(bytes)); + } +} + #[derive(PartialEq, Eq, Debug)] pub struct Iq { pub from: Option, From adece11fef464b3c33c00687e6a4b3f14a8e115d Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Thu, 4 Apr 2024 17:49:03 +0000 Subject: [PATCH 04/37] xmpp: make xml-headers optional in the c2s stream --- crates/projection-xmpp/src/lib.rs | 31 ++------------ crates/projection-xmpp/tests/lib.rs | 63 +++++++++++++++++++++++++++++ crates/proto-xmpp/src/stream.rs | 12 +++++- 3 files changed, 77 insertions(+), 29 deletions(-) diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 3b71b84..e3cb4e5 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -229,7 +229,6 @@ async fn socket_force_tls( use proto_xmpp::tls::*; let xml_reader = &mut NsReader::from_reader(reader); let xml_writer = &mut Writer::new(writer); - read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let event = Event::Decl(BytesDecl::new("1.0", None, None)); @@ -261,7 +260,6 @@ async fn socket_auth( reader_buf: &mut Vec, storage: &mut Storage, ) -> Result { - read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; @@ -328,7 +326,6 @@ async fn socket_final( user_handle: &mut PlayerConnection, rooms: &RoomRegistry, ) -> Result<()> { - read_xml_header(xml_reader, reader_buf).await?; let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; @@ -420,7 +417,7 @@ struct XmppConnection<'a> { impl<'a> XmppConnection<'a> { async fn handle_packet(&mut self, output: &mut Vec>, packet: ClientPacket) -> Result { let res = match packet { - proto::ClientPacket::Iq(iq) => { + ClientPacket::Iq(iq) => { self.handle_iq(output, iq).await; false } @@ -428,11 +425,11 @@ impl<'a> XmppConnection<'a> { self.handle_message(output, m).await?; false } - proto::ClientPacket::Presence(p) => { + ClientPacket::Presence(p) => { self.handle_presence(output, p).await?; false } - proto::ClientPacket::StreamEnd => { + ClientPacket::StreamEnd => { ServerStreamEnd.serialize(output); true } @@ -440,25 +437,3 @@ impl<'a> XmppConnection<'a> { Ok(res) } } - -async fn read_xml_header( - xml_reader: &mut NsReader<(impl AsyncBufRead + Unpin)>, - reader_buf: &mut Vec, -) -> Result<()> { - if let Event::Decl(bytes) = xml_reader.read_event_into_async(reader_buf).await? { - // this is header - if let Some(encoding) = bytes.encoding() { - let encoding = encoding?; - if &*encoding == b"UTF-8" { - Ok(()) - } else { - Err(anyhow!("Unsupported encoding: {encoding:?}")) - } - } else { - // Err(fail("No XML encoding provided")) - Ok(()) - } - } else { - Err(anyhow!("Expected XML header")) - } -} diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 0dae478..9dfae4c 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -187,6 +187,69 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_basic_without_headers() -> Result<()> { + tracing_subscriber::fmt::try_init(); + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + cert: "tests/certs/xmpp.pem".parse().unwrap(), + key: "tests/certs/xmpp.key".parse().unwrap(), + }; + let mut metrics = MetricsRegistry::new(); + let mut storage = Storage::open(StorageConfig { + db_path: ":memory:".into(), + }) + .await?; + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); + + // test scenario + + storage.create_user("tester").await?; + storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.addr).await?; + let mut s = TestScope::new(&mut stream); + tracing::info!("TCP connection established"); + + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + let buffer = s.buffer; + tracing::info!("TLS feature negotiation complete"); + + let connector = TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification)) + .with_no_client_auth(), + )); + tracing::info!("Initiating TLS connection..."); + let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + tracing::info!("TLS connection established"); + + let mut s = TestScopeTls::new(&mut stream, buffer); + + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + + stream.shutdown().await?; + + // wrap up + + server.terminate().await?; + Ok(()) +} + #[tokio::test] async fn terminate_socket() -> Result<()> { tracing_subscriber::fmt::try_init(); diff --git a/crates/proto-xmpp/src/stream.rs b/crates/proto-xmpp/src/stream.rs index d85df07..8f46f31 100644 --- a/crates/proto-xmpp/src/stream.rs +++ b/crates/proto-xmpp/src/stream.rs @@ -24,7 +24,17 @@ impl ClientStreamStart { reader: &mut NsReader, buf: &mut Vec, ) -> Result { - let incoming = skip_text!(reader, buf); + let mut incoming = skip_text!(reader, buf); + if let Event::Decl(bytes) = incoming { + // this is header + if let Some(encoding) = bytes.encoding() { + let encoding = encoding?; + if &*encoding != b"UTF-8" { + return Err(anyhow!("Unsupported encoding: {encoding:?}")); + } + } + incoming = skip_text!(reader, buf); + } if let Event::Start(e) = incoming { let (ns, local) = reader.resolve_element(e.name()); if ns != ResolveResult::Bound(Namespace(XMLNS.as_bytes())) { From 36b0d50d5103f6bd86ef52680ce7f001ae3b1d7d Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sat, 6 Apr 2024 23:01:24 +0000 Subject: [PATCH 05/37] irc: allow PART without a reason --- crates/projection-irc/tests/lib.rs | 7 +++---- crates/proto-irc/src/client.rs | 25 +++++++++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 240a067..36b9040 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -271,11 +271,10 @@ async fn scenario_two_users() -> Result<()> { // The first user should receive the message s1.expect(":tester2 PRIVMSG #test :Hello").await?; // Leave the channel from the first user - // TODO implement irc PART command - // s1.send("PART #test").await?; - // s1.expect(":tester1 PART #test").await?; + s1.send("PART #test").await?; + s1.expect(":tester1 PART #test").await?; // The second user should receive the PART message - // s2.expect(":tester1 PART #test").await?; + s2.expect(":tester1 PART #test").await?; Ok(()) } diff --git a/crates/proto-irc/src/client.rs b/crates/proto-irc/src/client.rs index 66cf107..a692e92 100644 --- a/crates/proto-irc/src/client.rs +++ b/crates/proto-irc/src/client.rs @@ -49,7 +49,7 @@ pub enum ClientMessage { }, Part { chan: Chan, - message: Str, + message: Option, }, /// `PRIVMSG :` PrivateMessage { @@ -194,14 +194,20 @@ fn client_message_topic(input: &str) -> IResult<&str, ClientMessage> { fn client_message_part(input: &str) -> IResult<&str, ClientMessage> { let (input, _) = tag("PART ")(input)?; let (input, chan) = chan(input)?; - let (input, _) = tag(" ")(input)?; + let (input, t) = opt(tag(" "))(input)?; + match t { + Some(_) => (), + None => { + return Ok((input, ClientMessage::Part { chan, message: None })); + } + } let (input, r) = opt(tag(":"))(input)?; let (input, message) = match r { Some(_) => token(input)?, None => receiver(input)?, }; - let message = message.into(); + let message = Some(message.into()); Ok((input, ClientMessage::Part { chan, message })) } @@ -369,7 +375,18 @@ mod test { let input = "PART #chan :Pokasiki !!!"; let expected = ClientMessage::Part { chan: Chan::Global("chan".into()), - message: "Pokasiki !!!".into(), + message: Some("Pokasiki !!!".into()), + }; + + let result = client_message(input); + assert_matches!(result, Ok(result) => assert_eq!(expected, result)); + } + #[test] + fn test_client_message_part_empty() { + let input = "PART #chan"; + let expected = ClientMessage::Part { + chan: Chan::Global("chan".into()), + message: None, }; let result = client_message(input); From 8b099f9be27de46d9ae3e188a3ca9f3e2f403bba Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sun, 7 Apr 2024 12:06:23 +0000 Subject: [PATCH 06/37] xmpp: fix handling of the `bind` iq --- crates/projection-xmpp/src/iq.rs | 4 ++-- crates/projection-xmpp/src/lib.rs | 18 ++++++++++++++---- crates/proto-xmpp/src/bind.rs | 3 +++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/crates/projection-xmpp/src/iq.rs b/crates/projection-xmpp/src/iq.rs index 9c50ecc..01135b1 100644 --- a/crates/projection-xmpp/src/iq.rs +++ b/crates/projection-xmpp/src/iq.rs @@ -24,9 +24,9 @@ impl<'a> XmppConnection<'a> { to: None, r#type: IqType::Result, body: BindResponse(Jid { - name: Some(Name("darova".into())), + name: Some(self.user.xmpp_name.clone()), server: Server("localhost".into()), - resource: Some(Resource("kek".into())), + resource: Some(self.user.xmpp_resource.clone()), }), }; req.serialize(output); diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index e3cb4e5..8474e8f 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -52,9 +52,17 @@ struct LoadedConfig { } struct Authenticated { + /// Identifier of the authenticated player. + /// + /// Used when communicating with lavina-core on behalf of the player. player_id: PlayerId, + /// The user's XMPP name. + /// + /// Used in `to` and `from` fields of XMPP messages. xmpp_name: Name, + /// The resource given to this user by the server. xmpp_resource: Resource, + /// The resource used by this user when joining MUCs. xmpp_muc_name: Resource, } @@ -307,11 +315,13 @@ async fn socket_auth( return Err(fail("passwords do not match")); } + let name: Str = name.as_str().into(); + Ok(Authenticated { - player_id: PlayerId::from(name.as_str())?, - xmpp_name: Name(name.to_string().into()), - xmpp_resource: Resource(name.to_string().into()), - xmpp_muc_name: Resource(name.to_string().into()), + player_id: PlayerId::from(name.clone())?, + xmpp_name: Name(name.clone()), + xmpp_resource: Resource(name.clone()), + xmpp_muc_name: Resource(name.clone()), }) } Err(e) => return Err(e), diff --git a/crates/proto-xmpp/src/bind.rs b/crates/proto-xmpp/src/bind.rs index 9984ae5..41c9e45 100644 --- a/crates/proto-xmpp/src/bind.rs +++ b/crates/proto-xmpp/src/bind.rs @@ -11,12 +11,15 @@ pub const XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-bind"; // TODO remove `pub` in newtypes, introduce validation +/// Name (node identifier) of an XMPP entity. Placed before the `@` in a JID. #[derive(PartialEq, Eq, Debug, Clone)] pub struct Name(pub Str); +/// Server name of an XMPP entity. Placed after the `@` and before the `/` in a JID. #[derive(PartialEq, Eq, Debug, Clone)] pub struct Server(pub Str); +/// Resource of an XMPP entity. Placed after the `/` in a JID. #[derive(PartialEq, Eq, Debug, Clone)] pub struct Resource(pub Str); From cccc05afe9dee47bbe2e825f91e5c03fff282cbd Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Thu, 11 Apr 2024 23:08:09 +0200 Subject: [PATCH 07/37] xmpp: ignore text elements with spaces at the stream root --- crates/projection-xmpp/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 8474e8f..c659e5b 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -372,7 +372,7 @@ async fn socket_final( res = &mut next_xml_event => 's: { let (ns, event) = res?; if let Event::Text(ref e) = event { - if e.iter().all(|x| *x == 0xA) { + if e.iter().all(|x| *x == b'\n' || *x == b' ') { break 's true; } } From fd694cd75cc43342b9548221de688fe3b34e89f4 Mon Sep 17 00:00:00 2001 From: Mikhail Date: Fri, 12 Apr 2024 21:32:21 +0000 Subject: [PATCH 08/37] Add message timestamps (#41) Resolves #38 Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/41 Co-authored-by: Mikhail Co-committed-by: Mikhail --- .pre-commit-config.yaml | 21 ++++++ Cargo.lock | 68 +++++++++++++++++++ crates/lavina-core/Cargo.toml | 1 + .../migrations/2_created_at_for_messages.sql | 1 + crates/lavina-core/src/repo/mod.rs | 5 +- docs/cheatsheet.md | 7 +- 6 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 crates/lavina-core/migrations/2_created_at_for_messages.sql diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..0393234 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-toml + - id: end-of-file-fixer + - id: fix-byte-order-marker + - id: mixed-line-ending + - id: trailing-whitespace + + - repo: local + hooks: + - id: fmt + name: fmt + description: Format + entry: cargo fmt + language: system + args: + - --all + types: [ rust ] + pass_filenames: false diff --git a/Cargo.lock b/Cargo.lock index e37a3dc..93eac0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,21 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.13" @@ -216,6 +231,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets 0.52.4", +] + [[package]] name = "clap" version = "4.5.3" @@ -274,6 +303,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + [[package]] name = "cpufeatures" version = "0.2.12" @@ -729,6 +764,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "idna" version = "0.5.0" @@ -818,6 +876,7 @@ name = "lavina-core" version = "0.0.2-dev" dependencies = [ "anyhow", + "chrono", "prometheus", "serde", "sqlx", @@ -2383,6 +2442,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.4", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/crates/lavina-core/Cargo.toml b/crates/lavina-core/Cargo.toml index 727835c..941b753 100644 --- a/crates/lavina-core/Cargo.toml +++ b/crates/lavina-core/Cargo.toml @@ -10,3 +10,4 @@ serde.workspace = true tokio.workspace = true tracing.workspace = true prometheus.workspace = true +chrono = "0.4.37" diff --git a/crates/lavina-core/migrations/2_created_at_for_messages.sql b/crates/lavina-core/migrations/2_created_at_for_messages.sql new file mode 100644 index 0000000..c11430a --- /dev/null +++ b/crates/lavina-core/migrations/2_created_at_for_messages.sql @@ -0,0 +1 @@ +alter table messages add column created_at text; \ No newline at end of file diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index d81ec0c..e9eee6c 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -87,14 +87,15 @@ impl Storage { return Err(anyhow!("No such user")); }; sqlx::query( - "insert into messages(room_id, id, content, author_id) - values (?, ?, ?, ?); + "insert into messages(room_id, id, content, author_id, created_at) + values (?, ?, ?, ?, ?); update rooms set message_count = message_count + 1 where id = ?;", ) .bind(room_id) .bind(id) .bind(content) .bind(author_id) + .bind(chrono::Utc::now().to_string()) .bind(room_id) .execute(&mut *executor) .await?; diff --git a/docs/cheatsheet.md b/docs/cheatsheet.md index 1ef20d0..ec9f63f 100644 --- a/docs/cheatsheet.md +++ b/docs/cheatsheet.md @@ -8,11 +8,12 @@ Some useful commands for development and testing. Following commands require `OpenSSL` to be installed. It is provided as `openssl` package in Arch Linux. -Generate self-signed TLS certificate: +Generate self-signed TLS certificate. Mind the common name (CN) field, it should match the domain name of the server. +Example for localhost: openssl req -x509 -newkey rsa:4096 -sha256 -days 365 -noenc \ -keyout certs/xmpp.key -out certs/xmpp.pem \ - -subj "/CN=example.com" + -subj "/CN=localhost" Print content of a TLS certificate: @@ -35,4 +36,4 @@ Connecting: Password should be the same as in storage. Example: - /connect -nocap 127.0.0.1 6667 parolchik1 kek \ No newline at end of file + /connect -nocap 127.0.0.1 6667 parolchik1 kek From 0944c449ca4125d20ccfa5df335759f305d333fc Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sat, 13 Apr 2024 02:32:41 +0200 Subject: [PATCH 09/37] xmpp: in integration tests extract server startup code --- crates/projection-xmpp/tests/lib.rs | 114 +++++++++++++--------------- 1 file changed, 53 insertions(+), 61 deletions(-) diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 9dfae4c..7ce583c 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -1,5 +1,4 @@ use std::io::ErrorKind; -use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -20,7 +19,7 @@ use tokio_rustls::TlsConnector; use lavina_core::player::PlayerRegistry; use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::room::RoomRegistry; -use projection_xmpp::{launch, ServerConfig}; +use projection_xmpp::{launch, RunningServer, ServerConfig}; use proto_xmpp::xml::{Continuation, FromXml, Parser}; pub async fn read_irc_message(reader: &mut BufReader>, buf: &mut Vec) -> Result { @@ -122,29 +121,49 @@ impl ServerCertVerifier for IgnoreCertVerification { } } +struct TestServer { + metrics: MetricsRegistry, + storage: Storage, + rooms: RoomRegistry, + players: PlayerRegistry, + server: RunningServer, +} +impl TestServer { + async fn start() -> Result { + let _ = tracing_subscriber::fmt::try_init(); + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + cert: "tests/certs/xmpp.pem".parse().unwrap(), + key: "tests/certs/xmpp.key".parse().unwrap(), + }; + let mut metrics = MetricsRegistry::new(); + let mut storage = Storage::open(StorageConfig { + db_path: ":memory:".into(), + }) + .await?; + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + Ok(TestServer { + metrics, + storage, + rooms, + players, + server, + }) + } +} + #[tokio::test] async fn scenario_basic() -> Result<()> { - tracing_subscriber::fmt::try_init(); - let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), - cert: "tests/certs/xmpp.pem".parse().unwrap(), - key: "tests/certs/xmpp.key".parse().unwrap(), - }; - let mut metrics = MetricsRegistry::new(); - let mut storage = Storage::open(StorageConfig { - db_path: ":memory:".into(), - }) - .await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); - let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); + let mut server = TestServer::start().await?; // test scenario - storage.create_user("tester").await?; - storage.set_password("tester", "password").await?; + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; - let mut stream = TcpStream::connect(server.addr).await?; + let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); tracing::info!("TCP connection established"); @@ -169,7 +188,7 @@ async fn scenario_basic() -> Result<()> { .with_no_client_auth(), )); tracing::info!("Initiating TLS connection..."); - let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; tracing::info!("TLS connection established"); let mut s = TestScopeTls::new(&mut stream, buffer); @@ -183,33 +202,20 @@ async fn scenario_basic() -> Result<()> { // wrap up - server.terminate().await?; + server.server.terminate().await?; Ok(()) } #[tokio::test] async fn scenario_basic_without_headers() -> Result<()> { - tracing_subscriber::fmt::try_init(); - let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), - cert: "tests/certs/xmpp.pem".parse().unwrap(), - key: "tests/certs/xmpp.key".parse().unwrap(), - }; - let mut metrics = MetricsRegistry::new(); - let mut storage = Storage::open(StorageConfig { - db_path: ":memory:".into(), - }) - .await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); - let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); + let mut server = TestServer::start().await?; // test scenario - storage.create_user("tester").await?; - storage.set_password("tester", "password").await?; + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; - let mut stream = TcpStream::connect(server.addr).await?; + let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); tracing::info!("TCP connection established"); @@ -233,7 +239,7 @@ async fn scenario_basic_without_headers() -> Result<()> { .with_no_client_auth(), )); tracing::info!("Initiating TLS connection..."); - let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; tracing::info!("TLS connection established"); let mut s = TestScopeTls::new(&mut stream, buffer); @@ -246,33 +252,20 @@ async fn scenario_basic_without_headers() -> Result<()> { // wrap up - server.terminate().await?; + server.server.terminate().await?; Ok(()) } #[tokio::test] async fn terminate_socket() -> Result<()> { - tracing_subscriber::fmt::try_init(); - let config = ServerConfig { - listen_on: "127.0.0.1:0".parse().unwrap(), - cert: "tests/certs/xmpp.pem".parse().unwrap(), - key: "tests/certs/xmpp.key".parse().unwrap(), - }; - let mut metrics = MetricsRegistry::new(); - let mut storage = Storage::open(StorageConfig { - db_path: ":memory:".into(), - }) - .await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); - let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); - let address: SocketAddr = ("127.0.0.1:0".parse().unwrap()); + let mut server = TestServer::start().await?; + // test scenario - storage.create_user("tester").await?; - storage.set_password("tester", "password").await?; + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; - let mut stream = TcpStream::connect(server.addr).await?; + let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); tracing::info!("TCP connection established"); @@ -288,7 +281,6 @@ async fn terminate_socket() -> Result<()> { assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); s.send(r#""#).await?; assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); - let buffer = s.buffer; let connector = TlsConnector::from(Arc::new( ClientConfig::builder() @@ -298,10 +290,10 @@ async fn terminate_socket() -> Result<()> { )); tracing::info!("Initiating TLS connection..."); - let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; tracing::info!("TLS connection established"); - server.terminate().await?; + server.server.terminate().await?; assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof); From 57b6af87326ec68e2698710feacd467efdb25693 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Mon, 15 Apr 2024 00:33:26 +0000 Subject: [PATCH 10/37] xmpp: configurable server hostname (#47) Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/47 --- config.toml | 1 + crates/projection-xmpp/src/iq.rs | 137 +++++++++++++------------ crates/projection-xmpp/src/lib.rs | 31 ++++-- crates/projection-xmpp/src/message.rs | 6 +- crates/projection-xmpp/src/presence.rs | 8 +- crates/projection-xmpp/src/updates.rs | 4 +- crates/projection-xmpp/tests/lib.rs | 1 + docs/running.md | 1 + 8 files changed, 103 insertions(+), 86 deletions(-) diff --git a/config.toml b/config.toml index 6104dce..4765aa0 100644 --- a/config.toml +++ b/config.toml @@ -9,6 +9,7 @@ server_name = "irc.localhost" listen_on = "127.0.0.1:5222" cert = "./certs/xmpp.pem" key = "./certs/xmpp.key" +hostname = "localhost" [storage] db_path = "db.sqlite" diff --git a/crates/projection-xmpp/src/iq.rs b/crates/projection-xmpp/src/iq.rs index 01135b1..6766e19 100644 --- a/crates/projection-xmpp/src/iq.rs +++ b/crates/projection-xmpp/src/iq.rs @@ -25,7 +25,7 @@ impl<'a> XmppConnection<'a> { r#type: IqType::Result, body: BindResponse(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), }; @@ -52,7 +52,7 @@ impl<'a> XmppConnection<'a> { req.serialize(output); } IqClientBody::DiscoInfo(info) => { - let response = disco_info(iq.to.as_deref(), &info); + let response = self.disco_info(iq.to.as_deref(), &info); let req = Iq { from: iq.to, id: iq.id, @@ -63,7 +63,7 @@ impl<'a> XmppConnection<'a> { req.serialize(output); } IqClientBody::DiscoItem(item) => { - let response = disco_items(iq.to.as_deref(), &item, self.rooms).await; + let response = self.disco_items(iq.to.as_deref(), &item, self.rooms).await; let req = Iq { from: iq.to, id: iq.id, @@ -87,78 +87,79 @@ impl<'a> XmppConnection<'a> { } } } -} -fn disco_info(to: Option<&str>, req: &InfoQuery) -> InfoQuery { - let identity; - let feature; - match to { - Some("localhost") => { - identity = vec![Identity { - category: "server".into(), - name: None, - r#type: "im".into(), - }]; - feature = vec![ - Feature::new("http://jabber.org/protocol/disco#info"), - Feature::new("http://jabber.org/protocol/disco#items"), - Feature::new("iq"), - Feature::new("presence"), - ] - } - Some("rooms.localhost") => { - identity = vec![Identity { - category: "conference".into(), - name: Some("Chat rooms".into()), - r#type: "text".into(), - }]; - feature = vec![ - Feature::new("http://jabber.org/protocol/disco#info"), - Feature::new("http://jabber.org/protocol/disco#items"), - Feature::new("http://jabber.org/protocol/muc"), - ] - } - _ => { - identity = vec![]; - feature = vec![]; - } - }; - InfoQuery { - node: None, - identity, - feature, - } -} + fn disco_info(&self, to: Option<&str>, req: &InfoQuery) -> InfoQuery { + let identity; + let feature; -async fn disco_items(to: Option<&str>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { - let item = match to { - Some("localhost") => { - vec![Item { - jid: Jid { + match to { + Some(r) if r == &*self.hostname => { + identity = vec![Identity { + category: "server".into(), name: None, - server: Server("rooms.localhost".into()), - resource: None, - }, - name: None, - node: None, - }] + r#type: "im".into(), + }]; + feature = vec![ + Feature::new("http://jabber.org/protocol/disco#info"), + Feature::new("http://jabber.org/protocol/disco#items"), + Feature::new("iq"), + Feature::new("presence"), + ] + } + Some(r) if r == &*self.hostname_rooms => { + identity = vec![Identity { + category: "conference".into(), + name: Some("Chat rooms".into()), + r#type: "text".into(), + }]; + feature = vec![ + Feature::new("http://jabber.org/protocol/disco#info"), + Feature::new("http://jabber.org/protocol/disco#items"), + Feature::new("http://jabber.org/protocol/muc"), + ] + } + _ => { + identity = vec![]; + feature = vec![]; + } + }; + InfoQuery { + node: None, + identity, + feature, } - Some("rooms.localhost") => { - let room_list = rooms.get_all_rooms().await; - room_list - .into_iter() - .map(|room_info| Item { + } + + async fn disco_items(&self, to: Option<&str>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { + let item = match to { + Some(r) if r == &*self.hostname => { + vec![Item { jid: Jid { - name: Some(Name(room_info.id.into_inner())), - server: Server("rooms.localhost".into()), + name: None, + server: Server(self.hostname_rooms.clone()), resource: None, }, name: None, node: None, - }) - .collect() - } - _ => vec![], - }; - ItemQuery { item } + }] + } + Some(r) if r == &*self.hostname_rooms => { + let room_list = rooms.get_all_rooms().await; + room_list + .into_iter() + .map(|room_info| Item { + jid: Jid { + name: Some(Name(room_info.id.into_inner())), + server: Server(self.hostname_rooms.clone()), + resource: None, + }, + name: None, + node: None, + }) + .collect() + } + _ => vec![], + }; + ItemQuery { item } + } } diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index c659e5b..0da3019 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -9,7 +9,6 @@ use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; -use anyhow::anyhow; use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; use quick_xml::events::{BytesDecl, Event}; @@ -44,6 +43,7 @@ pub struct ServerConfig { pub listen_on: SocketAddr, pub cert: PathBuf, pub key: PathBuf, + pub hostname: Str, } struct LoadedConfig { @@ -125,11 +125,12 @@ pub async fn launch( let players = players.clone(); let rooms = rooms.clone(); let storage = storage.clone(); + let hostname = config.hostname.clone(); let terminator = Terminator::spawn(|termination| { let stopped_tx = stopped_tx.clone(); let loaded_config = loaded_config.clone(); async move { - match handle_socket(loaded_config, stream, &socket_addr, players, rooms, storage, termination).await { + match handle_socket(loaded_config, stream, &socket_addr, players, rooms, storage, hostname, termination).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } @@ -164,12 +165,13 @@ pub async fn launch( } async fn handle_socket( - config: Arc, + cert_config: Arc, mut stream: TcpStream, socket_addr: &SocketAddr, mut players: PlayerRegistry, rooms: RoomRegistry, mut storage: Storage, + hostname: Str, termination: Deferred<()>, // TODO use it to stop the connection gracefully ) -> Result<()> { log::info!("Received an XMPP connection from {socket_addr}"); @@ -178,12 +180,12 @@ async fn handle_socket( let mut buf_reader = BufReader::new(reader); let mut buf_writer = BufWriter::new(writer); - socket_force_tls(&mut buf_reader, &mut buf_writer, &mut reader_buf).await?; + socket_force_tls(&mut buf_reader, &mut buf_writer, &mut reader_buf, &hostname).await?; let mut config = tokio_rustls::rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() - .with_single_cert(vec![config.cert.clone()], config.key.clone())?; + .with_single_cert(vec![cert_config.cert.clone()], cert_config.key.clone())?; config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); log::debug!("Accepting TLS connection..."); @@ -202,7 +204,7 @@ async fn handle_socket( log::info!("Socket handling was terminated"); return Ok(()) }, - authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage) => { + authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage, &hostname) => { match authenticated { Ok(authenticated) => { let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; @@ -213,6 +215,7 @@ async fn handle_socket( &authenticated, &mut connection, &rooms, + &hostname, ) .await?; }, @@ -233,16 +236,18 @@ async fn socket_force_tls( reader: &mut (impl AsyncBufRead + Unpin), writer: &mut (impl AsyncWrite + Unpin), reader_buf: &mut Vec, + hostname: &Str, ) -> Result<()> { use proto_xmpp::tls::*; let xml_reader = &mut NsReader::from_reader(reader); let xml_writer = &mut Writer::new(writer); + // TODO validate the server hostname received in the stream start let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; let event = Event::Decl(BytesDecl::new("1.0", None, None)); xml_writer.write_event_async(event).await?; let msg = ServerStreamStart { - from: "localhost".into(), + from: hostname.to_string(), lang: "en".into(), id: uuid::Uuid::new_v4().to_string(), version: "1.0".into(), @@ -267,12 +272,14 @@ async fn socket_auth( xml_writer: &mut Writer<(impl AsyncWrite + Unpin)>, reader_buf: &mut Vec, storage: &mut Storage, + hostname: &Str, ) -> Result { + // TODO validate the server hostname received in the stream start let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; ServerStreamStart { - from: "localhost".into(), + from: hostname.to_string(), lang: "en".into(), id: uuid::Uuid::new_v4().to_string(), version: "1.0".into(), @@ -335,12 +342,14 @@ async fn socket_final( authenticated: &Authenticated, user_handle: &mut PlayerConnection, rooms: &RoomRegistry, + hostname: &Str, ) -> Result<()> { + // TODO validate the server hostname received in the stream start let _ = ClientStreamStart::parse(xml_reader, reader_buf).await?; xml_writer.write_event_async(Event::Decl(BytesDecl::new("1.0", None, None))).await?; ServerStreamStart { - from: "localhost".into(), + from: hostname.to_string(), lang: "en".into(), id: uuid::Uuid::new_v4().to_string(), version: "1.0".into(), @@ -366,6 +375,8 @@ async fn socket_final( user: authenticated, user_handle, rooms, + hostname: hostname.clone(), + hostname_rooms: format!("rooms.{}", hostname).into(), }; let should_recreate_xml_future = select! { biased; @@ -422,6 +433,8 @@ struct XmppConnection<'a> { user: &'a Authenticated, user_handle: &'a mut PlayerConnection, rooms: &'a RoomRegistry, + hostname: Str, + hostname_rooms: Str, } impl<'a> XmppConnection<'a> { diff --git a/crates/projection-xmpp/src/message.rs b/crates/projection-xmpp/src/message.rs index 44aab05..a737b2b 100644 --- a/crates/projection-xmpp/src/message.rs +++ b/crates/projection-xmpp/src/message.rs @@ -18,17 +18,17 @@ impl<'a> XmppConnection<'a> { resource: _, }) = m.to { - if server.0.as_ref() == "rooms.localhost" && m.r#type == MessageType::Groupchat { + if server.0.as_ref() == &*self.hostname_rooms && m.r#type == MessageType::Groupchat { self.user_handle.send_message(RoomId::from(name.0.clone())?, m.body.clone().into()).await?; Message::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), from: Some(Jid { name: Some(name), - server: Server("rooms.localhost".into()), + server: Server(self.hostname_rooms.clone()), resource: Some(self.user.xmpp_muc_name.clone()), }), id: m.id, diff --git a/crates/projection-xmpp/src/presence.rs b/crates/projection-xmpp/src/presence.rs index eabf0fd..6f9540e 100644 --- a/crates/projection-xmpp/src/presence.rs +++ b/crates/projection-xmpp/src/presence.rs @@ -16,12 +16,12 @@ impl<'a> XmppConnection<'a> { Presence::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), from: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), ..Default::default() @@ -36,12 +36,12 @@ impl<'a> XmppConnection<'a> { Presence::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), from: Some(Jid { name: Some(name.clone()), - server: Server("rooms.localhost".into()), + server: Server(self.hostname_rooms.clone()), resource: Some(self.user.xmpp_muc_name.clone()), }), ..Default::default() diff --git a/crates/projection-xmpp/src/updates.rs b/crates/projection-xmpp/src/updates.rs index c211be8..0161b3f 100644 --- a/crates/projection-xmpp/src/updates.rs +++ b/crates/projection-xmpp/src/updates.rs @@ -21,12 +21,12 @@ impl<'a> XmppConnection<'a> { Message::<()> { to: Some(Jid { name: Some(self.user.xmpp_name.clone()), - server: Server("localhost".into()), + server: Server(self.hostname.clone()), resource: Some(self.user.xmpp_resource.clone()), }), from: Some(Jid { name: Some(Name(room_id.into_inner().into())), - server: Server("rooms.localhost".into()), + server: Server(self.hostname_rooms.clone()), resource: Some(Resource(author_id.into_inner().into())), }), id: None, diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 7ce583c..cc7e645 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -135,6 +135,7 @@ impl TestServer { listen_on: "127.0.0.1:0".parse().unwrap(), cert: "tests/certs/xmpp.pem".parse().unwrap(), key: "tests/certs/xmpp.key".parse().unwrap(), + hostname: "localhost".into(), }; let mut metrics = MetricsRegistry::new(); let mut storage = Storage::open(StorageConfig { diff --git a/docs/running.md b/docs/running.md index 61f3067..ad422a1 100644 --- a/docs/running.md +++ b/docs/running.md @@ -19,6 +19,7 @@ server_name = "irc.localhost" listen_on = "127.0.0.1:5222" cert = "./certs/xmpp.pem" key = "./certs/xmpp.key" +hostname = "localhost" [storage] db_path = "db.sqlite" From 0105a5b710fc9779fd4d1111676cbd124db4ba3c Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Mon, 15 Apr 2024 09:06:10 +0000 Subject: [PATCH 11/37] persistent memberships (#49) Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/49 --- crates/lavina-core/src/player.rs | 76 ++++++++++++++------- crates/lavina-core/src/repo/mod.rs | 5 +- crates/lavina-core/src/repo/room.rs | 33 +++++++++ crates/lavina-core/src/repo/user.rs | 30 ++++++++ crates/lavina-core/src/room.rs | 102 +++++++++++++++++----------- crates/projection-irc/src/lib.rs | 2 +- crates/projection-irc/tests/lib.rs | 102 +++++++++++++++++++++++++++- crates/projection-xmpp/src/lib.rs | 2 +- crates/projection-xmpp/tests/lib.rs | 2 +- src/main.rs | 2 +- 10 files changed, 284 insertions(+), 72 deletions(-) create mode 100644 crates/lavina-core/src/repo/room.rs create mode 100644 crates/lavina-core/src/repo/user.rs diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index eec22f8..0486808 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -7,16 +7,16 @@ //! //! A player actor is a serial handler of commands from a single player. It is preferable to run all per-player validations in the player actor, //! so that they don't overload the room actor. -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, RwLock}, -}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use prometheus::{IntGauge, Registry as MetricsRegistry}; use serde::Serialize; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::RwLock; use crate::prelude::*; +use crate::repo::Storage; use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; use crate::table::{AnonTable, Key as AnonKey}; @@ -208,36 +208,41 @@ pub enum Updates { #[derive(Clone)] pub struct PlayerRegistry(Arc>); impl PlayerRegistry { - pub fn empty(room_registry: RoomRegistry, metrics: &mut MetricsRegistry) -> Result { + pub fn empty( + room_registry: RoomRegistry, + storage: Storage, + metrics: &mut MetricsRegistry, + ) -> Result { let metric_active_players = IntGauge::new("chat_players_active", "Number of alive player actors")?; metrics.register(Box::new(metric_active_players.clone()))?; let inner = PlayerRegistryInner { room_registry, + storage, players: HashMap::new(), metric_active_players, }; Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) } - pub async fn get_or_create_player(&mut self, id: PlayerId) -> PlayerHandle { - let mut inner = self.0.write().unwrap(); - if let Some((handle, _)) = inner.players.get(&id) { + pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { + let mut inner = self.0.write().await; + if let Some((handle, _)) = inner.players.get(id) { handle.clone() } else { - let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone()); - inner.players.insert(id, (handle.clone(), fiber)); + let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone(), inner.storage.clone()).await; + inner.players.insert(id.clone(), (handle.clone(), fiber)); inner.metric_active_players.inc(); handle } } - pub async fn connect_to_player(&mut self, id: PlayerId) -> PlayerConnection { - let player_handle = self.get_or_create_player(id).await; + pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection { + let player_handle = self.get_or_launch_player(id).await; player_handle.subscribe().await } pub async fn shutdown_all(&mut self) -> Result<()> { - let mut inner = self.0.write().unwrap(); + let mut inner = self.0.write().await; for (i, (k, j)) in inner.players.drain() { k.send(ActorCommand::Stop).await; drop(k); @@ -252,6 +257,8 @@ impl PlayerRegistry { /// The player registry state representation. struct PlayerRegistryInner { room_registry: RoomRegistry, + storage: Storage, + /// Active player actors. players: HashMap)>, metric_active_players: IntGauge, } @@ -259,32 +266,49 @@ struct PlayerRegistryInner { /// Player actor inner state representation. struct Player { player_id: PlayerId, + storage_id: u32, connections: AnonTable>, my_rooms: HashMap, banned_from: HashSet, rx: Receiver, handle: PlayerHandle, rooms: RoomRegistry, + storage: Storage, } impl Player { - fn launch(player_id: PlayerId, rooms: RoomRegistry) -> (PlayerHandle, JoinHandle) { + async fn launch(player_id: PlayerId, rooms: RoomRegistry, storage: Storage) -> (PlayerHandle, JoinHandle) { let (tx, rx) = channel(32); let handle = PlayerHandle { tx }; let handle_clone = handle.clone(); + let storage_id = storage.retrieve_user_id_by_name(player_id.as_inner()).await.unwrap().unwrap(); let player = Player { player_id, + storage_id, + // connections are empty when the actor is just started connections: AnonTable::new(), + // room handlers will be loaded later in the started task my_rooms: HashMap::new(), - banned_from: HashSet::from([RoomId::from("Empty").unwrap()]), + // TODO implement and load bans + banned_from: HashSet::new(), rx, handle, rooms, + storage, }; let fiber = tokio::task::spawn(player.main_loop()); (handle_clone, fiber) } async fn main_loop(mut self) -> Self { + let rooms = self.storage.get_rooms_of_a_user(self.storage_id).await.unwrap(); + for room_id in rooms { + let room = self.rooms.get_room(&room_id).await; + if let Some(room) = room { + self.my_rooms.insert(room_id, room); + } else { + tracing::error!("Room #{room_id:?} not found"); + } + } while let Some(cmd) = self.rx.recv().await { match cmd { ActorCommand::AddConnection { sender, promise } => { @@ -372,7 +396,8 @@ impl Player { todo!(); } }; - room.subscribe(self.player_id.clone(), self.handle.clone()).await; + room.add_member(&self.player_id, self.storage_id).await; + room.subscribe(&self.player_id, self.handle.clone()).await; self.my_rooms.insert(room_id.clone(), room.clone()); let room_info = room.get_room_info().await; let update = Updates::RoomJoined { @@ -387,6 +412,7 @@ impl Player { let room = self.my_rooms.remove(&room_id); if let Some(room) = room { room.unsubscribe(&self.player_id).await; + room.remove_member(&self.player_id, self.storage_id).await; } let update = Updates::RoomLeft { room_id, @@ -396,12 +422,11 @@ impl Player { } async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) { - let room = self.rooms.get_room(&room_id).await; - if let Some(room) = room { - room.send_message(self.player_id.clone(), body.clone()).await; - } else { + let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - } + return; + }; + room.send_message(&self.player_id, body.clone()).await; let update = Updates::NewMessage { room_id, author_id: self.player_id.clone(), @@ -411,12 +436,11 @@ impl Player { } async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { - let room = self.rooms.get_room(&room_id).await; - if let Some(mut room) = room { - room.set_topic(self.player_id.clone(), new_topic.clone()).await; - } else { + let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - } + return; + }; + room.set_topic(&self.player_id, new_topic.clone()).await; let update = Updates::RoomTopicChanged { room_id, new_topic }; self.broadcast_update(update, connection_id).await; } diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index e9eee6c..e8e3854 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -11,6 +11,9 @@ use tokio::sync::Mutex; use crate::prelude::*; +mod room; +mod user; + #[derive(Deserialize, Debug, Clone)] pub struct StorageConfig { pub db_path: String, @@ -48,7 +51,7 @@ impl Storage { Ok(res) } - pub async fn retrieve_room_by_name(&mut self, name: &str) -> Result> { + pub async fn retrieve_room_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select id, name, topic, message_count diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs new file mode 100644 index 0000000..b568b9d --- /dev/null +++ b/crates/lavina-core/src/repo/room.rs @@ -0,0 +1,33 @@ +use anyhow::Result; + +use crate::repo::Storage; + +impl Storage { + pub async fn add_room_member(&self, room_id: u32, player_id: u32) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "insert into memberships(user_id, room_id, status) + values (?, ?, 1);", + ) + .bind(player_id) + .bind(room_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } + + pub async fn remove_room_member(&self, room_id: u32, player_id: u32) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "delete from memberships + where user_id = ? and room_id = ?;", + ) + .bind(player_id) + .bind(room_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } +} diff --git a/crates/lavina-core/src/repo/user.rs b/crates/lavina-core/src/repo/user.rs new file mode 100644 index 0000000..d836b8f --- /dev/null +++ b/crates/lavina-core/src/repo/user.rs @@ -0,0 +1,30 @@ +use anyhow::Result; + +use crate::repo::Storage; +use crate::room::RoomId; + +impl Storage { + pub async fn retrieve_user_id_by_name(&self, name: &str) -> Result> { + let mut executor = self.conn.lock().await; + let res: Option<(u32,)> = sqlx::query_as("select u.id from users u where u.name = ?;") + .bind(name) + .fetch_optional(&mut *executor) + .await?; + + Ok(res.map(|(id,)| id)) + } + + pub async fn get_rooms_of_a_user(&self, user_id: u32) -> Result> { + let mut executor = self.conn.lock().await; + let res: Vec<(String,)> = sqlx::query_as( + "select r.name + from memberships m inner join rooms r on m.room_id = r.id + where m.user_id = ?;", + ) + .bind(user_id) + .fetch_all(&mut *executor) + .await?; + + res.into_iter().map(|(room_id,)| RoomId::from(room_id)).collect() + } +} diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 04fdbb1..dbd14fd 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -1,4 +1,5 @@ //! Domain of rooms — chats with multiple participants. +use std::collections::HashSet; use std::{collections::HashMap, hash::Hash, sync::Arc}; use prometheus::{IntGauge, Registry as MetricRegistry}; @@ -48,27 +49,9 @@ impl RoomRegistry { pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result { let mut inner = self.0.write().await; - if let Some(room_handle) = inner.rooms.get(&room_id) { - // room was already loaded into memory - log::debug!("Room {} was loaded already", &room_id.0); + if let Some(room_handle) = inner.get_or_load_room(&room_id).await? { Ok(room_handle.clone()) - } else if let Some(stored_room) = inner.storage.retrieve_room_by_name(&*room_id.0).await? { - // room exists, but was not loaded - log::debug!("Loading room {}...", &room_id.0); - let room = Room { - storage_id: stored_room.id, - room_id: room_id.clone(), - subscriptions: HashMap::new(), // TODO figure out how to populate subscriptions - topic: stored_room.topic.into(), - message_count: stored_room.message_count, - storage: inner.storage.clone(), - }; - let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); - inner.rooms.insert(room_id, room_handle.clone()); - inner.metric_active_rooms.inc(); - Ok(room_handle) } else { - // room does not exist, create it and load log::debug!("Creating room {}...", &room_id.0); let topic = "New room"; let id = inner.storage.create_new_room(&*room_id.0, &*topic).await?; @@ -76,6 +59,7 @@ impl RoomRegistry { storage_id: id, room_id: room_id.clone(), subscriptions: HashMap::new(), + members: HashSet::new(), topic: topic.into(), message_count: 0, storage: inner.storage.clone(), @@ -88,9 +72,8 @@ impl RoomRegistry { } pub async fn get_room(&self, room_id: &RoomId) -> Option { - let inner = self.0.read().await; - let res = inner.rooms.get(room_id); - res.map(|r| r.clone()) + let mut inner = self.0.write().await; + inner.get_or_load_room(room_id).await.unwrap() } pub async fn get_all_rooms(&self) -> Vec { @@ -113,17 +96,66 @@ struct RoomRegistryInner { storage: Storage, } +impl RoomRegistryInner { + async fn get_or_load_room(&mut self, room_id: &RoomId) -> Result> { + if let Some(room_handle) = self.rooms.get(room_id) { + log::debug!("Room {} was loaded already", &room_id.0); + Ok(Some(room_handle.clone())) + } else if let Some(stored_room) = self.storage.retrieve_room_by_name(&*room_id.0).await? { + log::debug!("Loading room {}...", &room_id.0); + let room = Room { + storage_id: stored_room.id, + room_id: room_id.clone(), + subscriptions: HashMap::new(), + members: HashSet::new(), // TODO load members from storage + topic: stored_room.topic.into(), + message_count: stored_room.message_count, + storage: self.storage.clone(), + }; + let room_handle = RoomHandle(Arc::new(AsyncRwLock::new(room))); + self.rooms.insert(room_id.clone(), room_handle.clone()); + self.metric_active_rooms.inc(); + Ok(Some(room_handle)) + } else { + tracing::debug!("Room {} does not exist", &room_id.0); + Ok(None) + } + } +} + #[derive(Clone)] pub struct RoomHandle(Arc>); impl RoomHandle { - pub async fn subscribe(&self, player_id: PlayerId, player_handle: PlayerHandle) { + pub async fn subscribe(&self, player_id: &PlayerId, player_handle: PlayerHandle) { let mut lock = self.0.write().await; - lock.add_subscriber(player_id, player_handle).await; + tracing::info!("Adding a subscriber to a room"); + lock.subscriptions.insert(player_id.clone(), player_handle); + } + + pub async fn add_member(&self, player_id: &PlayerId, player_storage_id: u32) { + let mut lock = self.0.write().await; + tracing::info!("Adding a new member to a room"); + let room_storage_id = lock.storage_id; + lock.storage.add_room_member(room_storage_id, player_storage_id).await.unwrap(); + lock.members.insert(player_id.clone()); + let update = Updates::RoomJoined { + room_id: lock.room_id.clone(), + new_member_id: player_id.clone(), + }; + lock.broadcast_update(update, player_id).await; } pub async fn unsubscribe(&self, player_id: &PlayerId) { let mut lock = self.0.write().await; lock.subscriptions.remove(player_id); + } + + pub async fn remove_member(&self, player_id: &PlayerId, player_storage_id: u32) { + let mut lock = self.0.write().await; + tracing::info!("Removing a member from a room"); + let room_storage_id = lock.storage_id; + lock.storage.remove_room_member(room_storage_id, player_storage_id).await.unwrap(); + lock.members.remove(player_id); let update = Updates::RoomLeft { room_id: lock.room_id.clone(), former_member_id: player_id.clone(), @@ -131,7 +163,7 @@ impl RoomHandle { lock.broadcast_update(update, player_id).await; } - pub async fn send_message(&self, player_id: PlayerId, body: Str) { + pub async fn send_message(&self, player_id: &PlayerId, body: Str) { let mut lock = self.0.write().await; let res = lock.send_message(player_id, body).await; if let Err(err) = res { @@ -148,14 +180,14 @@ impl RoomHandle { } } - pub async fn set_topic(&mut self, changer_id: PlayerId, new_topic: Str) { + pub async fn set_topic(&self, changer_id: &PlayerId, new_topic: Str) { let mut lock = self.0.write().await; lock.topic = new_topic.clone(); let update = Updates::RoomTopicChanged { room_id: lock.room_id.clone(), new_topic: new_topic.clone(), }; - lock.broadcast_update(update, &changer_id).await; + lock.broadcast_update(update, changer_id).await; } } @@ -166,23 +198,15 @@ struct Room { room_id: RoomId, /// Player actors on the local node which are subscribed to this room's updates. subscriptions: HashMap, + /// Members of the room. + members: HashSet, /// The total number of messages. Used to calculate the id of the new message. message_count: u32, topic: Str, storage: Storage, } impl Room { - async fn add_subscriber(&mut self, player_id: PlayerId, player_handle: PlayerHandle) { - tracing::info!("Adding a subscriber to room"); - self.subscriptions.insert(player_id.clone(), player_handle); - let update = Updates::RoomJoined { - room_id: self.room_id.clone(), - new_member_id: player_id.clone(), - }; - self.broadcast_update(update, &player_id).await; - } - - async fn send_message(&mut self, author_id: PlayerId, body: Str) -> Result<()> { + async fn send_message(&mut self, author_id: &PlayerId, body: Str) -> Result<()> { tracing::info!("Adding a message to room"); self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?; self.message_count += 1; @@ -191,7 +215,7 @@ impl Room { author_id: author_id.clone(), body, }; - self.broadcast_update(update, &author_id).await; + self.broadcast_update(update, author_id).await; Ok(()) } diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index ee47f9b..e52e92a 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -392,7 +392,7 @@ async fn handle_registered_socket<'a>( log::info!("Handling registered user: {user:?}"); let player_id = PlayerId::from(user.nickname.clone())?; - let mut connection = players.connect_to_player(player_id.clone()).await; + let mut connection = players.connect_to_player(&player_id).await; let text: Str = format!("Welcome to {} Server", &config.server_name).into(); ServerMessage { diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 36b9040..3618467 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -111,7 +111,36 @@ impl TestServer { }) .await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); + let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + Ok(TestServer { + metrics, + storage, + rooms, + players, + server, + }) + } + + async fn reboot(mut self) -> Result { + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + server_name: "testserver".into(), + }; + let TestServer { + mut metrics, + mut storage, + rooms, + mut players, + server, + } = self; + server.terminate().await?; + players.shutdown_all().await.unwrap(); + drop(players); + drop(rooms); + let mut metrics = MetricsRegistry::new(); + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, @@ -152,6 +181,76 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_join_and_reboot() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + // Open a connection and join a channel + + s.send("PASS password").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_server_introduction("tester").await?; + s.expect_nothing().await?; + s.send("JOIN #test").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + s.send("PRIVMSG #test :Hello").await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + stream.shutdown().await?; + + // Open a new connection and expect to be force-joined to the channel + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + async fn test(s: &mut TestScope<'_>) -> Result<()> { + s.send("PASS password").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_server_introduction("tester").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + Ok(()) + } + test(&mut s).await?; + stream.shutdown().await?; + + // Reboot the server + + let server = server.reboot().await?; + + // Open a new connection and expect to be force-joined to the channel + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + test(&mut s).await?; + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} + #[tokio::test] async fn scenario_force_join_msg() -> Result<()> { let mut server = TestServer::start().await?; @@ -407,7 +506,6 @@ async fn scenario_cap_sasl_fail() -> Result<()> { #[tokio::test] async fn terminate_socket_scenario() -> Result<()> { let mut server = TestServer::start().await?; - let address: SocketAddr = ("127.0.0.1:0".parse().unwrap()); // test scenario diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 0da3019..30e0a3c 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -207,7 +207,7 @@ async fn handle_socket( authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage, &hostname) => { match authenticated { Ok(authenticated) => { - let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; + let mut connection = players.connect_to_player(&authenticated.player_id).await; socket_final( &mut xml_reader, &mut xml_writer, diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index cc7e645..29d0368 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -143,7 +143,7 @@ impl TestServer { }) .await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, diff --git a/src/main.rs b/src/main.rs index 8111074..0d03a89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,7 +52,7 @@ async fn main() -> Result<()> { let mut metrics = MetricsRegistry::new(); let storage = Storage::open(storage_config).await?; let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; - let mut players = PlayerRegistry::empty(rooms.clone(), &mut metrics)?; + let mut players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?; let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?; let irc = projection_irc::launch( irc_config, From 757d7c56657f5ec6a291d2cbae5a7c7ee90b32bc Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Mon, 15 Apr 2024 09:12:23 +0000 Subject: [PATCH 12/37] persistent room topics (#50) Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/50 --- crates/lavina-core/src/repo/room.rs | 15 +++++++++++++++ crates/lavina-core/src/room.rs | 2 ++ 2 files changed, 17 insertions(+) diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs index b568b9d..96b89f2 100644 --- a/crates/lavina-core/src/repo/room.rs +++ b/crates/lavina-core/src/repo/room.rs @@ -30,4 +30,19 @@ impl Storage { Ok(()) } + + pub async fn set_room_topic(&mut self, id: u32, topic: &str) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "update rooms + set topic = ? + where id = ?;", + ) + .bind(topic) + .bind(id) + .fetch_optional(&mut *executor) + .await?; + + Ok(()) + } } diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index dbd14fd..a5e2dab 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -182,7 +182,9 @@ impl RoomHandle { pub async fn set_topic(&self, changer_id: &PlayerId, new_topic: Str) { let mut lock = self.0.write().await; + let storage_id = lock.storage_id; lock.topic = new_topic.clone(); + lock.storage.set_room_topic(storage_id, &new_topic).await.unwrap(); let update = Updates::RoomTopicChanged { room_id: lock.room_id.clone(), new_topic: new_topic.clone(), From 6d493d83a3033ff86cbb7defa3bcc4d7940cd811 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Mon, 15 Apr 2024 18:18:51 +0200 Subject: [PATCH 13/37] xmpp: use the Jid type in IQs' to and from fields, separate presence handling --- crates/projection-xmpp/src/iq.rs | 36 ++++++++--- crates/projection-xmpp/src/presence.rs | 87 +++++++++++++++----------- crates/proto-xmpp/src/client.rs | 20 +++--- 3 files changed, 90 insertions(+), 53 deletions(-) diff --git a/crates/projection-xmpp/src/iq.rs b/crates/projection-xmpp/src/iq.rs index 6766e19..eebbc4d 100644 --- a/crates/projection-xmpp/src/iq.rs +++ b/crates/projection-xmpp/src/iq.rs @@ -3,7 +3,7 @@ use quick_xml::events::Event; use lavina_core::room::RoomRegistry; -use proto_xmpp::bind::{BindResponse, Jid, Name, Resource, Server}; +use proto_xmpp::bind::{BindResponse, Jid, Name, Server}; use proto_xmpp::client::{Iq, IqError, IqErrorType, IqType}; use proto_xmpp::disco::{Feature, Identity, InfoQuery, Item, ItemQuery}; use proto_xmpp::roster::RosterQuery; @@ -17,7 +17,7 @@ use proto_xmpp::xml::ToXml; impl<'a> XmppConnection<'a> { pub async fn handle_iq(&self, output: &mut Vec>, iq: Iq) { match iq.body { - IqClientBody::Bind(b) => { + IqClientBody::Bind(_) => { let req = Iq { from: None, id: iq.id, @@ -52,7 +52,7 @@ impl<'a> XmppConnection<'a> { req.serialize(output); } IqClientBody::DiscoInfo(info) => { - let response = self.disco_info(iq.to.as_deref(), &info); + let response = self.disco_info(iq.to.as_ref(), &info); let req = Iq { from: iq.to, id: iq.id, @@ -63,7 +63,7 @@ impl<'a> XmppConnection<'a> { req.serialize(output); } IqClientBody::DiscoItem(item) => { - let response = self.disco_items(iq.to.as_deref(), &item, self.rooms).await; + let response = self.disco_items(iq.to.as_ref(), &item, self.rooms).await; let req = Iq { from: iq.to, id: iq.id, @@ -88,12 +88,16 @@ impl<'a> XmppConnection<'a> { } } - fn disco_info(&self, to: Option<&str>, req: &InfoQuery) -> InfoQuery { + fn disco_info(&self, to: Option<&Jid>, req: &InfoQuery) -> InfoQuery { let identity; let feature; match to { - Some(r) if r == &*self.hostname => { + Some(Jid { + name: None, + server, + resource: None, + }) if server.0 == self.hostname => { identity = vec![Identity { category: "server".into(), name: None, @@ -106,7 +110,11 @@ impl<'a> XmppConnection<'a> { Feature::new("presence"), ] } - Some(r) if r == &*self.hostname_rooms => { + Some(Jid { + name: None, + server, + resource: None, + }) if server.0 == self.hostname_rooms => { identity = vec![Identity { category: "conference".into(), name: Some("Chat rooms".into()), @@ -130,9 +138,13 @@ impl<'a> XmppConnection<'a> { } } - async fn disco_items(&self, to: Option<&str>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { + async fn disco_items(&self, to: Option<&Jid>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { let item = match to { - Some(r) if r == &*self.hostname => { + Some(Jid { + name: None, + server, + resource: None, + }) if server.0 == self.hostname => { vec![Item { jid: Jid { name: None, @@ -143,7 +155,11 @@ impl<'a> XmppConnection<'a> { node: None, }] } - Some(r) if r == &*self.hostname_rooms => { + Some(Jid { + name: None, + server, + resource: None, + }) if server.0 == self.hostname_rooms => { let room_list = rooms.get_all_rooms().await; room_list .into_iter() diff --git a/crates/projection-xmpp/src/presence.rs b/crates/projection-xmpp/src/presence.rs index 6f9540e..82ccb61 100644 --- a/crates/projection-xmpp/src/presence.rs +++ b/crates/projection-xmpp/src/presence.rs @@ -4,7 +4,7 @@ use quick_xml::events::Event; use lavina_core::prelude::*; use lavina_core::room::RoomId; -use proto_xmpp::bind::{Jid, Server}; +use proto_xmpp::bind::{Jid, Name, Server}; use proto_xmpp::client::Presence; use proto_xmpp::xml::{Ignore, ToXml}; @@ -12,42 +12,59 @@ use crate::XmppConnection; impl<'a> XmppConnection<'a> { pub async fn handle_presence(&mut self, output: &mut Vec>, p: Presence) -> Result<()> { - let response = if p.to.is_none() { - Presence::<()> { - to: Some(Jid { - name: Some(self.user.xmpp_name.clone()), - server: Server(self.hostname.clone()), - resource: Some(self.user.xmpp_resource.clone()), - }), - from: Some(Jid { - name: Some(self.user.xmpp_name.clone()), - server: Server(self.hostname.clone()), - resource: Some(self.user.xmpp_resource.clone()), - }), - ..Default::default() + match p.to { + None => { + self.self_presence(output).await; } - } else if let Some(Jid { - name: Some(name), - server, - resource: Some(resource), - }) = p.to - { - let a = self.user_handle.join_room(RoomId::from(name.0.clone())?).await?; - Presence::<()> { - to: Some(Jid { - name: Some(self.user.xmpp_name.clone()), - server: Server(self.hostname.clone()), - resource: Some(self.user.xmpp_resource.clone()), - }), - from: Some(Jid { - name: Some(name.clone()), - server: Server(self.hostname_rooms.clone()), - resource: Some(self.user.xmpp_muc_name.clone()), - }), - ..Default::default() + Some(Jid { + name: Some(name), + server, + // resources in MUCs are members' personas – not implemented (yet?) + resource: Some(_), + }) if server.0 == self.hostname_rooms => { + self.muc_presence(name, output).await?; } - } else { - Presence::<()>::default() + _ => { + // TODO other presence cases + let response = Presence::<()>::default(); + response.serialize(output); + } + } + Ok(()) + } + + async fn self_presence(&mut self, output: &mut Vec>) { + let response = Presence::<()> { + to: Some(Jid { + name: Some(self.user.xmpp_name.clone()), + server: Server(self.hostname.clone()), + resource: Some(self.user.xmpp_resource.clone()), + }), + from: Some(Jid { + name: Some(self.user.xmpp_name.clone()), + server: Server(self.hostname.clone()), + resource: Some(self.user.xmpp_resource.clone()), + }), + ..Default::default() + }; + response.serialize(output); + } + + async fn muc_presence(&mut self, name: Name, output: &mut Vec>) -> Result<()> { + let a = self.user_handle.join_room(RoomId::from(name.0.clone())?).await?; + // TODO handle bans + let response = Presence::<()> { + to: Some(Jid { + name: Some(self.user.xmpp_name.clone()), + server: Server(self.hostname.clone()), + resource: Some(self.user.xmpp_resource.clone()), + }), + from: Some(Jid { + name: Some(name.clone()), + server: Server(self.hostname_rooms.clone()), + resource: Some(self.user.xmpp_muc_name.clone()), + }), + ..Default::default() }; response.serialize(output); Ok(()) diff --git a/crates/proto-xmpp/src/client.rs b/crates/proto-xmpp/src/client.rs index 4943283..85b3979 100644 --- a/crates/proto-xmpp/src/client.rs +++ b/crates/proto-xmpp/src/client.rs @@ -295,9 +295,9 @@ impl ToXml for IqError { #[derive(PartialEq, Eq, Debug)] pub struct Iq { - pub from: Option, + pub from: Option, pub id: String, - pub to: Option, + pub to: Option, pub r#type: IqType, pub body: T, } @@ -323,9 +323,9 @@ enum IqParserInner { Final(IqParserState), } struct IqParserState { - pub from: Option, + pub from: Option, pub id: Option, - pub to: Option, + pub to: Option, pub r#type: Option, pub body: Option, } @@ -348,13 +348,15 @@ impl Parser for IqParser { let attr = fail_fast!(attr); if attr.key.0 == b"from" { let value = fail_fast!(std::str::from_utf8(&*attr.value)); - state.from = Some(value.to_string()) + let value = fail_fast!(Jid::from_string(value)); + state.from = Some(value) } else if attr.key.0 == b"id" { let value = fail_fast!(std::str::from_utf8(&*attr.value)); state.id = Some(value.to_string()) } else if attr.key.0 == b"to" { let value = fail_fast!(std::str::from_utf8(&*attr.value)); - state.to = Some(value.to_string()) + let value = fail_fast!(Jid::from_string(value)); + state.to = Some(value) } else if attr.key.0 == b"type" { let value = fail_fast!(IqType::from_str(&*attr.value)); state.r#type = Some(value); @@ -431,15 +433,17 @@ impl ToXml for Iq { let mut start = BytesStart::new(start); let mut attrs = vec![]; if let Some(ref from) = self.from { + let value = from.to_string().into_bytes(); attrs.push(Attribute { key: QName(b"from"), - value: from.as_bytes().into(), + value: value.into(), }); }; if let Some(ref to) = self.to { + let value = to.to_string().into_bytes(); attrs.push(Attribute { key: QName(b"to"), - value: to.as_bytes().into(), + value: value.into(), }); } attrs.push(Attribute { From 6bba699d87b60d3b09d9ebeab4c7afbaa3cf67e4 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Mon, 15 Apr 2024 23:08:43 +0200 Subject: [PATCH 14/37] xmpp: disco-info iq for rooms --- crates/projection-xmpp/src/iq.rs | 68 ++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/crates/projection-xmpp/src/iq.rs b/crates/projection-xmpp/src/iq.rs index eebbc4d..0031c36 100644 --- a/crates/projection-xmpp/src/iq.rs +++ b/crates/projection-xmpp/src/iq.rs @@ -2,7 +2,7 @@ use quick_xml::events::Event; -use lavina_core::room::RoomRegistry; +use lavina_core::room::{RoomId, RoomRegistry}; use proto_xmpp::bind::{BindResponse, Jid, Name, Server}; use proto_xmpp::client::{Iq, IqError, IqErrorType, IqType}; use proto_xmpp::disco::{Feature, Identity, InfoQuery, Item, ItemQuery}; @@ -52,15 +52,29 @@ impl<'a> XmppConnection<'a> { req.serialize(output); } IqClientBody::DiscoInfo(info) => { - let response = self.disco_info(iq.to.as_ref(), &info); - let req = Iq { - from: iq.to, - id: iq.id, - to: None, - r#type: IqType::Result, - body: response, - }; - req.serialize(output); + let response = self.disco_info(iq.to.as_ref(), &info).await; + match response { + Ok(response) => { + let req = Iq { + from: iq.to, + id: iq.id, + to: None, + r#type: IqType::Result, + body: response, + }; + req.serialize(output); + } + Err(response) => { + let req = Iq { + from: iq.to, + id: iq.id, + to: None, + r#type: IqType::Error, + body: response, + }; + req.serialize(output); + } + } } IqClientBody::DiscoItem(item) => { let response = self.disco_items(iq.to.as_ref(), &item, self.rooms).await; @@ -88,7 +102,7 @@ impl<'a> XmppConnection<'a> { } } - fn disco_info(&self, to: Option<&Jid>, req: &InfoQuery) -> InfoQuery { + async fn disco_info(&self, to: Option<&Jid>, req: &InfoQuery) -> Result { let identity; let feature; @@ -126,16 +140,44 @@ impl<'a> XmppConnection<'a> { Feature::new("http://jabber.org/protocol/muc"), ] } + Some(Jid { + name: Some(room_name), + server, + resource: None, + }) if server.0 == self.hostname_rooms => { + let room_id = RoomId::from(room_name.0.clone()).unwrap(); + let Some(_) = self.rooms.get_room(&room_id).await else { + // TODO should return item-not-found + // example: + // + // + // Conference room does not exist + // + return Err(IqError { + r#type: IqErrorType::Cancel, + }); + }; + identity = vec![Identity { + category: "conference".into(), + name: Some(room_id.into_inner().to_string()), + r#type: "text".into(), + }]; + feature = vec![ + Feature::new("http://jabber.org/protocol/disco#info"), + Feature::new("http://jabber.org/protocol/disco#items"), + Feature::new("http://jabber.org/protocol/muc"), + ] + } _ => { identity = vec![]; feature = vec![]; } }; - InfoQuery { + Ok(InfoQuery { node: None, identity, feature, - } + }) } async fn disco_items(&self, to: Option<&Jid>, req: &ItemQuery, rooms: &RoomRegistry) -> ItemQuery { From 048660624d38aa2165ee2ddf7de319d653718097 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 16 Apr 2024 11:35:14 +0000 Subject: [PATCH 15/37] irc: support registration with different order of NICK/USER/CAP END commands (#51) Resolves #33 Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/51 --- crates/projection-irc/src/lib.rs | 466 +++++++++++++++-------------- crates/projection-irc/tests/lib.rs | 39 +++ 2 files changed, 285 insertions(+), 220 deletions(-) diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index e52e92a..7f1b49e 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -86,6 +86,249 @@ async fn handle_socket( Ok(()) } +struct RegistrationState { + /// The last received `NICK` message. + future_nickname: Option, + /// The last received `USER` message. + future_username: Option<(Str, Str)>, + enabled_capabilities: Capabilities, + /// `CAP LS` or `CAP REQ` was received, but not `CAP END`. + cap_negotiation_in_progress: bool, + /// The last received `PASS` message. + pass: Option, + authentication_started: bool, + validated_user: Option, +} + +impl RegistrationState { + fn new() -> RegistrationState { + RegistrationState { + future_nickname: None, + future_username: None, + enabled_capabilities: Capabilities::None, + cap_negotiation_in_progress: false, + pass: None, + authentication_started: false, + validated_user: None, + } + } + + /// Handle an incoming message from the client during the registration process. + /// + /// Returns `Some` if the user is fully registered, `None` if the registration is still in progress. + async fn handle_msg( + &mut self, + msg: ClientMessage, + writer: &mut BufWriter>, + storage: &mut Storage, + config: &ServerConfig, + ) -> Result> { + match msg { + ClientMessage::Pass { password } => { + self.pass = Some(password); + Ok(None) + } + ClientMessage::Capability { subcommand } => match subcommand { + CapabilitySubcommand::List { code: _ } => { + self.cap_negotiation_in_progress = true; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Cap { + target: self.future_nickname.clone().unwrap_or_else(|| "*".into()), + subcmd: CapSubBody::Ls("sasl=PLAIN".into()), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } + CapabilitySubcommand::Req(caps) => { + self.cap_negotiation_in_progress = true; + let mut acked = vec![]; + let mut naked = vec![]; + for cap in caps { + if &*cap.name == "sasl" { + if cap.to_disable { + self.enabled_capabilities &= !Capabilities::Sasl; + } else { + self.enabled_capabilities |= Capabilities::Sasl; + } + acked.push(cap); + } else { + naked.push(cap); + } + } + let mut ack_body = String::new(); + for cap in acked { + if cap.to_disable { + ack_body.push('-'); + } + ack_body += &*cap.name; + } + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Cap { + target: self.future_nickname.clone().unwrap_or_else(|| "*".into()), + subcmd: CapSubBody::Ack(ack_body.into()), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } + CapabilitySubcommand::End => { + let Some((ref username, ref realname)) = self.future_username else { + self.cap_negotiation_in_progress = false; + return Ok(None); + }; + let Some(nickname) = self.future_nickname.clone() else { + self.cap_negotiation_in_progress = false; + return Ok(None); + }; + let username = username.clone(); + let realname = realname.clone(); + let candidate_user = RegisteredUser { + nickname: nickname.clone(), + username, + realname, + }; + self.finalize_auth(candidate_user, writer, storage, config).await + } + }, + ClientMessage::Nick { nickname } => { + if self.cap_negotiation_in_progress { + self.future_nickname = Some(nickname); + Ok(None) + } else if let Some((username, realname)) = &self.future_username.clone() { + let candidate_user = RegisteredUser { + nickname: nickname.clone(), + username: username.clone(), + realname: realname.clone(), + }; + self.finalize_auth(candidate_user, writer, storage, config).await + } else { + self.future_nickname = Some(nickname); + Ok(None) + } + } + ClientMessage::User { username, realname } => { + if self.cap_negotiation_in_progress { + self.future_username = Some((username, realname)); + Ok(None) + } else if let Some(nickname) = self.future_nickname.clone() { + let candidate_user = RegisteredUser { + nickname: nickname.clone(), + username, + realname, + }; + self.finalize_auth(candidate_user, writer, storage, config).await + } else { + self.future_username = Some((username, realname)); + Ok(None) + } + } + ClientMessage::Authenticate(body) => { + if !self.authentication_started { + tracing::debug!("Received authentication request"); + if &*body == "PLAIN" { + tracing::debug!("Authentication request with method PLAIN"); + self.authentication_started = true; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::Authenticate("+".into()), + } + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } else { + let target = self.future_nickname.clone().unwrap_or_else(|| "*".into()); + sasl_fail_message(config.server_name.clone(), target, "Unsupported mechanism".into()) + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } + } else { + let body = AuthBody::from_str(body.as_bytes())?; + if let Err(e) = auth_user(storage, &body.login, &body.password).await { + tracing::warn!("Authentication failed: {:?}", e); + let target = self.future_nickname.clone().unwrap_or_else(|| "*".into()); + sasl_fail_message(config.server_name.clone(), target, "Bad credentials".into()) + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } else { + let login: Str = body.login.into(); + self.validated_user = Some(login.clone()); + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N900LoggedIn { + nick: login.clone(), + address: login.clone(), + account: login.clone(), + message: format!("You are now logged in as {}", login).into(), + }, + } + .write_async(writer) + .await?; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N903SaslSuccess { + nick: login.clone(), + message: "SASL authentication successful".into(), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; + Ok(None) + } + } + + // TODO handle abortion of authentication + } + _ => Ok(None), + } + } + + async fn finalize_auth( + &mut self, + candidate_user: RegisteredUser, + writer: &mut BufWriter>, + storage: &mut Storage, + config: &ServerConfig, + ) -> Result> { + if self.enabled_capabilities.contains(Capabilities::Sasl) + && self.validated_user.as_ref() == Some(&candidate_user.nickname) + { + Ok(Some(candidate_user)) + } else { + let Some(candidate_password) = &self.pass else { + sasl_fail_message( + config.server_name.clone(), + candidate_user.nickname.clone(), + "User credentials was not provided".into(), + ) + .write_async(writer) + .await?; + writer.flush().await?; + return Ok(None); + }; + auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; + Ok(Some(candidate_user)) + } + } +} + async fn handle_registration<'a>( reader: &mut BufReader>, writer: &mut BufWriter>, @@ -94,14 +337,7 @@ async fn handle_registration<'a>( ) -> Result { let mut buffer = vec![]; - let mut future_nickname: Option = None; - let mut future_username: Option<(Str, Str)> = None; - let mut enabled_capabilities = Capabilities::None; - let mut cap_negotiation_in_progress = false; // if true, expect `CAP END` to complete registration - - let mut pass: Option = None; - let mut authentication_started = false; - let mut validated_user = None; + let mut state = RegistrationState::new(); let user = loop { let res = read_irc_message(reader, &mut buffer).await; @@ -132,218 +368,8 @@ async fn handle_registration<'a>( } }; tracing::debug!("Incoming IRC message: {msg:?}"); - match msg { - ClientMessage::Pass { password } => { - pass = Some(password); - } - ClientMessage::Capability { subcommand } => match subcommand { - CapabilitySubcommand::List { code: _ } => { - cap_negotiation_in_progress = true; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::Cap { - target: future_nickname.clone().unwrap_or_else(|| "*".into()), - subcmd: CapSubBody::Ls("sasl=PLAIN".into()), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; - } - CapabilitySubcommand::Req(caps) => { - cap_negotiation_in_progress = true; - let mut acked = vec![]; - let mut naked = vec![]; - for cap in caps { - if &*cap.name == "sasl" { - if cap.to_disable { - enabled_capabilities &= !Capabilities::Sasl; - } else { - enabled_capabilities |= Capabilities::Sasl; - } - acked.push(cap); - } else { - naked.push(cap); - } - } - let mut ack_body = String::new(); - for cap in acked { - if cap.to_disable { - ack_body.push('-'); - } - ack_body += &*cap.name; - } - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::Cap { - target: future_nickname.clone().unwrap_or_else(|| "*".into()), - subcmd: CapSubBody::Ack(ack_body.into()), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; - } - CapabilitySubcommand::End => { - let Some((ref username, ref realname)) = future_username else { - todo!(); - }; - let Some(nickname) = future_nickname.clone() else { - todo!(); - }; - let username = username.clone(); - let realname = realname.clone(); - let candidate_user = RegisteredUser { - nickname: nickname.clone(), - username, - realname, - }; - if enabled_capabilities.contains(Capabilities::Sasl) - && validated_user.as_ref() == Some(&candidate_user.nickname) - { - break Ok(candidate_user); - } else { - let Some(candidate_password) = pass else { - sasl_fail_message( - config.server_name.clone(), - nickname.clone(), - "User credentials was not provided".into(), - ) - .write_async(writer) - .await?; - writer.flush().await?; - continue; - }; - auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; - break Ok(candidate_user); - } - } - }, - ClientMessage::Nick { nickname } => { - if cap_negotiation_in_progress { - future_nickname = Some(nickname); - } else if let Some((username, realname)) = future_username.clone() { - let candidate_user = RegisteredUser { - nickname: nickname.clone(), - username, - realname, - }; - let Some(candidate_password) = pass else { - sasl_fail_message( - config.server_name.clone(), - nickname.clone(), - "User credentials was not provided".into(), - ) - .write_async(writer) - .await?; - writer.flush().await?; - continue; - }; - auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; - break Ok(candidate_user); - } else { - future_nickname = Some(nickname); - } - } - ClientMessage::User { username, realname } => { - if cap_negotiation_in_progress { - future_username = Some((username, realname)); - } else if let Some(nickname) = future_nickname.clone() { - let candidate_user = RegisteredUser { - nickname: nickname.clone(), - username, - realname, - }; - let Some(candidate_password) = pass else { - sasl_fail_message( - config.server_name.clone(), - nickname.clone(), - "User credentials was not provided".into(), - ) - .write_async(writer) - .await?; - writer.flush().await?; - continue; - }; - auth_user(storage, &*candidate_user.nickname, &*candidate_password).await?; - break Ok(candidate_user); - } else { - future_username = Some((username, realname)); - } - } - ClientMessage::Authenticate(body) => { - if !authentication_started { - tracing::debug!("Received authentication request"); - if &*body == "PLAIN" { - tracing::debug!("Authentication request with method PLAIN"); - authentication_started = true; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::Authenticate("+".into()), - } - .write_async(writer) - .await?; - writer.flush().await?; - } else { - if let Some(nickname) = future_nickname.clone() { - sasl_fail_message( - config.server_name.clone(), - nickname.clone(), - "Unsupported mechanism".into(), - ) - .write_async(writer) - .await?; - writer.flush().await?; - } else { - break Err(anyhow::Error::msg("Wrong authentication sequence")); - } - } - } else { - let body = AuthBody::from_str(body.as_bytes())?; - if let Err(e) = auth_user(storage, &body.login, &body.password).await { - tracing::warn!("Authentication failed: {:?}", e); - if let Some(nickname) = future_nickname.clone() { - sasl_fail_message(config.server_name.clone(), nickname.clone(), "Bad credentials".into()) - .write_async(writer) - .await?; - writer.flush().await?; - } else { - } - } else { - let login: Str = body.login.into(); - validated_user = Some(login.clone()); - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N900LoggedIn { - nick: login.clone(), - address: login.clone(), - account: login.clone(), - message: format!("You are now logged in as {}", login).into(), - }, - } - .write_async(writer) - .await?; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N903SaslSuccess { - nick: login.clone(), - message: "SASL authentication successful".into(), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; - } - } - - // TODO handle abortion of authentication - } - _ => {} + if let Some(user) = state.handle_msg(msg, writer, storage, config).await? { + break Ok(user); } buffer.clear(); }?; diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 3618467..145033b 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -421,6 +421,45 @@ async fn scenario_cap_full_negotiation() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_cap_full_negotiation_nick_last() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + s.send("CAP LS 302").await?; + s.expect(":testserver CAP * LS :sasl=PLAIN").await?; + s.send("CAP REQ :sasl").await?; + s.expect(":testserver CAP * ACK :sasl").await?; + s.send("AUTHENTICATE PLAIN").await?; + s.expect(":testserver AUTHENTICATE +").await?; + s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' + s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; + s.expect(":testserver 903 tester :SASL authentication successful").await?; + s.send("CAP END").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.send("NICK tester").await?; + + s.expect_server_introduction("tester").await?; + s.expect_nothing().await?; + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} + #[tokio::test] async fn scenario_cap_short_negotiation() -> Result<()> { let mut server = TestServer::start().await?; From fbb3d4f4f963bb0c1706d7f47f2a31c7f8931d7c Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 16 Apr 2024 17:44:34 +0200 Subject: [PATCH 16/37] xmpp: rewrite xml element parsers using coroutines --- crates/proto-xmpp/src/roster.rs | 47 +++++++--------------- crates/proto-xmpp/src/session.rs | 45 +++++++-------------- crates/proto-xmpp/src/xml/ignore.rs | 62 +++++++++++------------------ 3 files changed, 53 insertions(+), 101 deletions(-) diff --git a/crates/proto-xmpp/src/roster.rs b/crates/proto-xmpp/src/roster.rs index 4e89981..f7d0305 100644 --- a/crates/proto-xmpp/src/roster.rs +++ b/crates/proto-xmpp/src/roster.rs @@ -1,47 +1,30 @@ use quick_xml::events::{BytesStart, Event}; use crate::xml::*; -use anyhow::{anyhow as ffail, Result}; +use anyhow::{anyhow, Result}; +use quick_xml::name::ResolveResult; pub const XMLNS: &'static str = "jabber:iq:roster"; #[derive(PartialEq, Eq, Debug)] pub struct RosterQuery; -pub struct QueryParser(QueryParserInner); - -enum QueryParserInner { - Initial, - InQuery, -} - -impl Parser for QueryParser { - type Output = Result; - - fn consume<'a>( - self: Self, - namespace: quick_xml::name::ResolveResult, - event: &quick_xml::events::Event<'a>, - ) -> Continuation { - match self.0 { - QueryParserInner::Initial => match event { - Event::Start(_) => Continuation::Continue(QueryParser(QueryParserInner::InQuery)), - Event::Empty(_) => Continuation::Final(Ok(RosterQuery)), - _ => Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))), - }, - QueryParserInner::InQuery => match event { - Event::End(_) => Continuation::Final(Ok(RosterQuery)), - _ => Continuation::Final(Err(ffail!("Unexpected XML event: {event:?}"))), - }, - } - } -} - impl FromXml for RosterQuery { - type P = QueryParser; + type P = impl Parser>; fn parse() -> Self::P { - QueryParser(QueryParserInner::Initial) + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + match event { + Event::Start(_) => (), + Event::Empty(_) => return Ok(RosterQuery), + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + } + (namespace, event) = yield; + match event { + Event::End(_) => return Ok(RosterQuery), + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + } + } } } diff --git a/crates/proto-xmpp/src/session.rs b/crates/proto-xmpp/src/session.rs index 569742d..59d394d 100644 --- a/crates/proto-xmpp/src/session.rs +++ b/crates/proto-xmpp/src/session.rs @@ -2,46 +2,29 @@ use quick_xml::events::{BytesStart, Event}; use crate::xml::*; use anyhow::{anyhow, Result}; +use quick_xml::name::ResolveResult; pub const XMLNS: &'static str = "urn:ietf:params:xml:ns:xmpp-session"; #[derive(PartialEq, Eq, Debug)] pub struct Session; -pub struct SessionParser(SessionParserInner); - -enum SessionParserInner { - Initial, - InSession, -} - -impl Parser for SessionParser { - type Output = Result; - - fn consume<'a>( - self: Self, - namespace: quick_xml::name::ResolveResult, - event: &quick_xml::events::Event<'a>, - ) -> Continuation { - match self.0 { - SessionParserInner::Initial => match event { - Event::Start(_) => Continuation::Continue(SessionParser(SessionParserInner::InSession)), - Event::Empty(_) => Continuation::Final(Ok(Session)), - _ => Continuation::Final(Err(anyhow!("Unexpected XML event: {event:?}"))), - }, - SessionParserInner::InSession => match event { - Event::End(_) => Continuation::Final(Ok(Session)), - _ => Continuation::Final(Err(anyhow!("Unexpected XML event: {event:?}"))), - }, - } - } -} - impl FromXml for Session { - type P = SessionParser; + type P = impl Parser>; fn parse() -> Self::P { - SessionParser(SessionParserInner::Initial) + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + match event { + Event::Start(_) => (), + Event::Empty(_) => return Ok(Session), + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + } + (namespace, event) = yield; + match event { + Event::End(_) => return Ok(Session), + _ => return Err(anyhow!("Unexpected XML event: {event:?}")), + } + } } } diff --git a/crates/proto-xmpp/src/xml/ignore.rs b/crates/proto-xmpp/src/xml/ignore.rs index f4af358..fc89410 100644 --- a/crates/proto-xmpp/src/xml/ignore.rs +++ b/crates/proto-xmpp/src/xml/ignore.rs @@ -1,49 +1,35 @@ use super::*; -use derive_more::From; #[derive(Default, Debug, PartialEq, Eq)] pub struct Ignore; -#[derive(From)] -pub struct IgnoreParser(IgnoreParserInner); - -enum IgnoreParserInner { - Initial, - InTag { name: Vec, depth: u8 }, -} - -impl Parser for IgnoreParser { - type Output = Result; - - fn consume<'a>(self: Self, _: ResolveResult, event: &Event<'a>) -> Continuation { - match self.0 { - IgnoreParserInner::Initial => match event { - Event::Start(bytes) => { - let name = bytes.name().0.to_owned(); - Continuation::Continue(IgnoreParserInner::InTag { name, depth: 0 }.into()) - } - Event::Empty(_) => Continuation::Final(Ok(Ignore)), - _ => Continuation::Final(Ok(Ignore)), - }, - IgnoreParserInner::InTag { name, depth } => match event { - Event::End(bytes) if name == bytes.name().0 => { - if depth == 0 { - Continuation::Final(Ok(Ignore)) - } else { - Continuation::Continue(IgnoreParserInner::InTag { name, depth: depth - 1 }.into()) - } - } - _ => Continuation::Continue(IgnoreParserInner::InTag { name, depth }.into()), - }, - } - } -} - impl FromXml for Ignore { - type P = IgnoreParser; + type P = impl Parser>; fn parse() -> Self::P { - IgnoreParserInner::Initial.into() + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + let mut depth = match event { + Event::Start(bytes) => 0, + Event::Empty(_) => return Ok(Ignore), + _ => return Ok(Ignore), + }; + loop { + (namespace, event) = yield; + match event { + Event::End(_) => { + if depth == 0 { + return Ok(Ignore); + } else { + depth -= 1; + } + } + Event::Start(_) => { + depth += 1; + } + _ => (), + } + } + } } } From 02a8309d9e5cb663d424f3e417779532c29c771f Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Thu, 18 Apr 2024 01:42:28 +0200 Subject: [PATCH 17/37] xmpp: relax the jid regex a bit --- crates/proto-xmpp/src/bind.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/proto-xmpp/src/bind.rs b/crates/proto-xmpp/src/bind.rs index 41c9e45..dc0d1ce 100644 --- a/crates/proto-xmpp/src/bind.rs +++ b/crates/proto-xmpp/src/bind.rs @@ -48,7 +48,7 @@ impl Jid { use lazy_static::lazy_static; use regex::Regex; lazy_static! { - static ref RE: Regex = Regex::new(r"^(([a-zA-Z]+)@)?([a-zA-Z.]+)(/([a-zA-Z\-]+))?$").unwrap(); + static ref RE: Regex = Regex::new(r"^(([a-zA-Z0-9]+)@)?([^@/]+)(/([a-zA-Z0-9\-]+))?$").unwrap(); } let m = RE.captures(i).ok_or(anyhow!("Incorrectly format jid: {i}"))?; From cebe3541791c0a28cbd986323a85457d7a39b199 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Fri, 19 Apr 2024 14:27:19 +0200 Subject: [PATCH 18/37] update libraries --- Cargo.lock | 243 ++++++++++++++++++---------------- Cargo.toml | 2 +- crates/lavina-core/Cargo.toml | 2 +- rust-toolchain | 2 +- 4 files changed, 131 insertions(+), 118 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 93eac0b..a3e6e02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,9 +41,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] name = "android-tzdata" @@ -110,9 +110,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.81" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" +checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" [[package]] name = "assert_matches" @@ -140,15 +140,15 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" [[package]] name = "backtrace" -version = "0.3.69" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", "cc", @@ -165,6 +165,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" + [[package]] name = "base64ct" version = "1.6.0" @@ -197,9 +203,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.4" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" @@ -215,15 +221,15 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "cc" -version = "1.0.90" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" +checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" [[package]] name = "cfg-if" @@ -233,23 +239,23 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.37" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "wasm-bindgen", - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] name = "clap" -version = "4.5.3" +version = "4.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "949626d00e063efc93b6dca932419ceb5432f99769911c0b995f7e884c778813" +checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" dependencies = [ "clap_builder", "clap_derive", @@ -269,14 +275,14 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.3" +version = "4.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90239a040c80f5e14809ca132ddc4176ab33d5e17e49691793296e3fcb34d72f" +checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] @@ -320,9 +326,9 @@ dependencies = [ [[package]] name = "crc" -version = "3.0.1" +version = "3.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86ec7a15cbe22e59248fc7eadb1907dab5ba09372595da4d73dd805ed4417dfe" +checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" dependencies = [ "crc-catalog", ] @@ -360,9 +366,9 @@ dependencies = [ [[package]] name = "der" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" dependencies = [ "const-oid", "pem-rfc7468", @@ -402,9 +408,9 @@ checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" [[package]] name = "either" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" dependencies = [ "serde", ] @@ -444,15 +450,15 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "fastrand" -version = "2.0.1" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" [[package]] name = "figment" -version = "0.10.15" +version = "0.10.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7270677e7067213e04f323b55084586195f18308cd7546cfac9f873344ccceb6" +checksum = "d032832d74006f99547004d49410a4b4218e4c33382d56ca3ff89df74f86b953" dependencies = [ "atomic", "pear", @@ -546,7 +552,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] @@ -590,9 +596,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" dependencies = [ "cfg-if", "libc", @@ -726,9 +732,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" [[package]] name = "hyper" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a" +checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" dependencies = [ "bytes", "futures-channel", @@ -799,9 +805,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.5" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", "hashbrown", @@ -830,9 +836,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" @@ -950,9 +956,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.1" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "mgmt-api" @@ -1146,7 +1152,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] @@ -1181,14 +1187,14 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] name = "pin-project-lite" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" [[package]] name = "pin-utils" @@ -1231,9 +1237,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.79" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" dependencies = [ "unicode-ident", ] @@ -1246,7 +1252,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", "version_check", "yansi", ] @@ -1342,9 +1348,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.35" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" dependencies = [ "proc-macro2", ] @@ -1390,9 +1396,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", @@ -1413,17 +1419,17 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" [[package]] name = "reqwest" -version = "0.12.0" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58b48d98d932f4ee75e541614d32a7f44c889b72bd9c2e04d95edd135989df88" +checksum = "3e6cc1e89e689536eb5aeede61520e874df5a4707df811cd5da4aa5fbb2aae19" dependencies = [ - "base64", + "base64 0.22.0", "bytes", "futures-core", "futures-util", @@ -1533,7 +1539,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.7", ] [[package]] @@ -1557,7 +1563,7 @@ name = "sasl" version = "0.0.2-dev" dependencies = [ "anyhow", - "base64", + "base64 0.22.0", ] [[package]] @@ -1584,29 +1590,29 @@ checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" [[package]] name = "serde" -version = "1.0.197" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" +checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.197" +version = "1.0.198" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" +checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] name = "serde_json" -version = "1.0.114" +version = "1.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" dependencies = [ "itoa", "ryu", @@ -1695,9 +1701,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" @@ -1839,7 +1845,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418" dependencies = [ "atoi", - "base64", + "base64 0.21.7", "bitflags 2.5.0", "byteorder", "bytes", @@ -1881,7 +1887,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e" dependencies = [ "atoi", - "base64", + "base64 0.21.7", "bitflags 2.5.0", "byteorder", "crc", @@ -1948,9 +1954,9 @@ dependencies = [ [[package]] name = "strsim" -version = "0.11.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "subtle" @@ -1971,9 +1977,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.53" +version = "2.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" +checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" dependencies = [ "proc-macro2", "quote", @@ -2015,7 +2021,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] @@ -2045,9 +2051,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.36.0" +version = "1.37.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" +checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" dependencies = [ "backtrace", "bytes", @@ -2070,7 +2076,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] @@ -2106,9 +2112,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.9" +version = "0.22.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e40bb779c5187258fd7aad0eb68cb8706a0a81fa712fbea808ab43c4b8374c4" +checksum = "fb686a972ccef8537b39eead3968b0e8616cb5040dbb9bba93007c8e07c9215f" dependencies = [ "indexmap", "serde", @@ -2165,7 +2171,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] @@ -2355,7 +2361,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", "wasm-bindgen-shared", ] @@ -2389,7 +2395,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2448,7 +2454,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -2466,7 +2472,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -2486,17 +2492,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" dependencies = [ - "windows_aarch64_gnullvm 0.52.4", - "windows_aarch64_msvc 0.52.4", - "windows_i686_gnu 0.52.4", - "windows_i686_msvc 0.52.4", - "windows_x86_64_gnu 0.52.4", - "windows_x86_64_gnullvm 0.52.4", - "windows_x86_64_msvc 0.52.4", + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", ] [[package]] @@ -2507,9 +2514,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" [[package]] name = "windows_aarch64_msvc" @@ -2519,9 +2526,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" [[package]] name = "windows_i686_gnu" @@ -2531,9 +2538,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" [[package]] name = "windows_i686_msvc" @@ -2543,9 +2556,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" [[package]] name = "windows_x86_64_gnu" @@ -2555,9 +2568,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" [[package]] name = "windows_x86_64_gnullvm" @@ -2567,9 +2580,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" [[package]] name = "windows_x86_64_msvc" @@ -2579,24 +2592,24 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "winnow" -version = "0.6.5" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dffa400e67ed5a4dd237983829e66475f0a4a26938c4b04c21baede6262215b8" +checksum = "f0c976aaaa0e1f90dbb21e9587cdaf1d9679a1cde8875c0d6bd83ab96a208352" dependencies = [ "memchr", ] [[package]] name = "winreg" -version = "0.50.0" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" dependencies = [ "cfg-if", "windows-sys 0.48.0", @@ -2625,7 +2638,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.60", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 1751f0f..7234d24 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ clap = { version = "4.4.4", features = ["derive"] } serde = { version = "1.0.152", features = ["rc", "serde_derive"] } tracing = "0.1.37" # logging & tracing api prometheus = { version = "0.13.3", default-features = false } -base64 = "0.21.3" +base64 = "0.22.0" lavina-core = { path = "crates/lavina-core" } tracing-subscriber = "0.3.16" sasl = { path = "crates/sasl" } diff --git a/crates/lavina-core/Cargo.toml b/crates/lavina-core/Cargo.toml index 941b753..531e959 100644 --- a/crates/lavina-core/Cargo.toml +++ b/crates/lavina-core/Cargo.toml @@ -5,7 +5,7 @@ version.workspace = true [dependencies] anyhow.workspace = true -sqlx = { version = "0.7.0-alpha.2", features = ["sqlite", "migrate"] } +sqlx = { version = "0.7.4", features = ["sqlite", "migrate"] } serde.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/rust-toolchain b/rust-toolchain index a693462..4ac8229 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2024-03-20 +nightly-2024-04-19 From 5a09b743c9170f6844894b4becafc7afa2fd6334 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sat, 20 Apr 2024 17:09:44 +0200 Subject: [PATCH 19/37] return AlreadyJoined when a player attempts to join a room they are already in --- crates/lavina-core/src/player.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 0486808..8693752 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -177,6 +177,7 @@ pub enum ClientCommand { pub enum JoinResult { Success(RoomInfo), + AlreadyJoined, Banned, } @@ -388,6 +389,9 @@ impl Player { if self.banned_from.contains(&room_id) { return JoinResult::Banned; } + if self.my_rooms.contains_key(&room_id) { + return JoinResult::AlreadyJoined; + } let room = match self.rooms.get_or_create_room(room_id.clone()).await { Ok(room) => room, From ddb348bee929a3aa7b8d6545a96f5448280b4ce0 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sun, 21 Apr 2024 19:45:50 +0200 Subject: [PATCH 20/37] refactor lavina core by grouping public services into a new LavinaCore struct. this will be useful in future when additional services will be introduced and passed as dependencies --- crates/lavina-core/src/lib.rs | 29 +++++++++++++++++++++++ crates/projection-irc/src/lib.rs | 14 +++++------ crates/projection-irc/tests/lib.rs | 36 +++++++++++------------------ crates/projection-xmpp/src/lib.rs | 16 ++++++------- crates/projection-xmpp/tests/lib.rs | 18 ++++++--------- src/http.rs | 17 +++++++------- src/main.rs | 33 ++++++-------------------- 7 files changed, 79 insertions(+), 84 deletions(-) diff --git a/crates/lavina-core/src/lib.rs b/crates/lavina-core/src/lib.rs index 401e49e..ff52363 100644 --- a/crates/lavina-core/src/lib.rs +++ b/crates/lavina-core/src/lib.rs @@ -1,4 +1,11 @@ //! Domain definitions and implementation of common chat logic. +use anyhow::Result; +use prometheus::Registry as MetricsRegistry; + +use crate::player::PlayerRegistry; +use crate::repo::Storage; +use crate::room::RoomRegistry; + pub mod player; pub mod prelude; pub mod repo; @@ -6,3 +13,25 @@ pub mod room; pub mod terminator; mod table; + +#[derive(Clone)] +pub struct LavinaCore { + pub players: PlayerRegistry, + pub rooms: RoomRegistry, +} + +impl LavinaCore { + pub async fn new(mut metrics: MetricsRegistry, storage: Storage) -> Result { + // TODO shutdown all services in reverse order on error + let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; + let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?; + Ok(LavinaCore { players, rooms }) + } + + pub async fn shutdown(mut self) -> Result<()> { + self.players.shutdown_all().await?; + drop(self.players); + drop(self.rooms); + Ok(()) + } +} diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 7f1b49e..1513546 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -18,6 +18,7 @@ use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::{RoomId, RoomInfo, RoomRegistry}; use lavina_core::terminator::Terminator; +use lavina_core::LavinaCore; use proto_irc::client::CapabilitySubcommand; use proto_irc::client::{client_message, ClientMessage}; use proto_irc::server::CapSubBody; @@ -54,8 +55,7 @@ async fn handle_socket( config: ServerConfig, mut stream: TcpStream, socket_addr: &SocketAddr, - players: PlayerRegistry, - rooms: RoomRegistry, + mut core: LavinaCore, termination: Deferred<()>, // TODO use it to stop the connection gracefully mut storage: Storage, ) -> Result<()> { @@ -75,7 +75,7 @@ async fn handle_socket( match registered_user { Ok(user) => { log::debug!("User registered"); - handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?; + handle_registered_socket(config, core.players, core.rooms, &mut reader, &mut writer, user).await?; } Err(err) => { log::debug!("Registration failed: {err}"); @@ -942,8 +942,7 @@ impl RunningServer { pub async fn launch( config: ServerConfig, - players: PlayerRegistry, - rooms: RoomRegistry, + core: LavinaCore, metrics: MetricsRegistry, storage: Storage, ) -> Result { @@ -984,13 +983,12 @@ pub async fn launch( } let terminator = Terminator::spawn(|termination| { - let players = players.clone(); - let rooms = rooms.clone(); + let core = core.clone(); let current_connections_clone = current_connections.clone(); let stopped_tx = stopped_tx.clone(); let storage = storage.clone(); async move { - match handle_socket(config, stream, &socket_addr, players, rooms, termination, storage).await { + match handle_socket(config, stream, &socket_addr, core, termination, storage).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 145033b..f2f4505 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -9,7 +9,7 @@ use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; use lavina_core::repo::{Storage, StorageConfig}; -use lavina_core::{player::PlayerRegistry, room::RoomRegistry}; +use lavina_core::LavinaCore; use projection_irc::APP_VERSION; use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig}; struct TestScope<'a> { @@ -94,8 +94,7 @@ impl<'a> TestScope<'a> { struct TestServer { metrics: MetricsRegistry, storage: Storage, - rooms: RoomRegistry, - players: PlayerRegistry, + core: LavinaCore, server: RunningServer, } impl TestServer { @@ -110,43 +109,36 @@ impl TestServer { db_path: ":memory:".into(), }) .await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); - let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + let core = LavinaCore::new(metrics.clone(), storage.clone()).await?; + let server = launch(config, core.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, storage, - rooms, - players, + core, server, }) } - async fn reboot(mut self) -> Result { + async fn reboot(self) -> Result { let config = ServerConfig { listen_on: "127.0.0.1:0".parse().unwrap(), server_name: "testserver".into(), }; let TestServer { - mut metrics, - mut storage, - rooms, - mut players, + metrics: _, + storage, + mut core, server, } = self; server.terminate().await?; - players.shutdown_all().await.unwrap(); - drop(players); - drop(rooms); - let mut metrics = MetricsRegistry::new(); - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); - let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + core.shutdown().await?; + let metrics = MetricsRegistry::new(); + let core = LavinaCore::new(metrics.clone(), storage.clone()).await?; + let server = launch(config, core.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, storage, - rooms, - players, + core, server, }) } diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 30e0a3c..6539254 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -26,6 +26,7 @@ use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::RoomRegistry; use lavina_core::terminator::Terminator; +use lavina_core::LavinaCore; use proto_xmpp::bind::{Name, Resource}; use proto_xmpp::stream::*; use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml}; @@ -79,8 +80,7 @@ impl RunningServer { pub async fn launch( config: ServerConfig, - players: PlayerRegistry, - rooms: RoomRegistry, + core: LavinaCore, metrics: MetricsRegistry, storage: Storage, ) -> Result { @@ -122,15 +122,14 @@ pub async fn launch( // TODO kill the older connection and restart it continue; } - let players = players.clone(); - let rooms = rooms.clone(); + let core = core.clone(); let storage = storage.clone(); let hostname = config.hostname.clone(); let terminator = Terminator::spawn(|termination| { let stopped_tx = stopped_tx.clone(); let loaded_config = loaded_config.clone(); async move { - match handle_socket(loaded_config, stream, &socket_addr, players, rooms, storage, hostname, termination).await { + match handle_socket(loaded_config, stream, &socket_addr, core, storage, hostname, termination).await { Ok(_) => log::info!("Connection terminated"), Err(err) => log::warn!("Connection failed: {err}"), } @@ -168,8 +167,7 @@ async fn handle_socket( cert_config: Arc, mut stream: TcpStream, socket_addr: &SocketAddr, - mut players: PlayerRegistry, - rooms: RoomRegistry, + mut core: LavinaCore, mut storage: Storage, hostname: Str, termination: Deferred<()>, // TODO use it to stop the connection gracefully @@ -207,14 +205,14 @@ async fn handle_socket( authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage, &hostname) => { match authenticated { Ok(authenticated) => { - let mut connection = players.connect_to_player(&authenticated.player_id).await; + let mut connection = core.players.connect_to_player(&authenticated.player_id).await; socket_final( &mut xml_reader, &mut xml_writer, &mut reader_buf, &authenticated, &mut connection, - &rooms, + &core.rooms, &hostname, ) .await?; diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 29d0368..be687a4 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -16,9 +16,8 @@ use tokio_rustls::rustls::client::ServerCertVerifier; use tokio_rustls::rustls::{ClientConfig, ServerName}; use tokio_rustls::TlsConnector; -use lavina_core::player::PlayerRegistry; use lavina_core::repo::{Storage, StorageConfig}; -use lavina_core::room::RoomRegistry; +use lavina_core::LavinaCore; use projection_xmpp::{launch, RunningServer, ServerConfig}; use proto_xmpp::xml::{Continuation, FromXml, Parser}; @@ -124,8 +123,7 @@ impl ServerCertVerifier for IgnoreCertVerification { struct TestServer { metrics: MetricsRegistry, storage: Storage, - rooms: RoomRegistry, - players: PlayerRegistry, + core: LavinaCore, server: RunningServer, } impl TestServer { @@ -137,19 +135,17 @@ impl TestServer { key: "tests/certs/xmpp.key".parse().unwrap(), hostname: "localhost".into(), }; - let mut metrics = MetricsRegistry::new(); - let mut storage = Storage::open(StorageConfig { + let metrics = MetricsRegistry::new(); + let storage = Storage::open(StorageConfig { db_path: ":memory:".into(), }) .await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); - let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics).unwrap(); - let server = launch(config, players.clone(), rooms.clone(), metrics.clone(), storage.clone()).await.unwrap(); + let core = LavinaCore::new(metrics.clone(), storage.clone()).await?; + let server = launch(config, core.clone(), metrics.clone(), storage.clone()).await.unwrap(); Ok(TestServer { metrics, storage, - rooms, - players, + core, server, }) } diff --git a/src/http.rs b/src/http.rs index 89ba4ce..302bf5f 100644 --- a/src/http.rs +++ b/src/http.rs @@ -16,6 +16,7 @@ use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::RoomRegistry; use lavina_core::terminator::Terminator; +use lavina_core::LavinaCore; use mgmt_api::*; @@ -29,20 +30,20 @@ pub struct ServerConfig { pub async fn launch( config: ServerConfig, metrics: MetricsRegistry, - rooms: RoomRegistry, + core: LavinaCore, storage: Storage, ) -> Result { log::info!("Starting the http service"); let listener = TcpListener::bind(config.listen_on).await?; log::debug!("Listener started"); - let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, rooms, storage, rx.map(|_| ()))); + let terminator = Terminator::spawn(|rx| main_loop(listener, metrics, core, storage, rx.map(|_| ()))); Ok(terminator) } async fn main_loop( listener: TcpListener, metrics: MetricsRegistry, - rooms: RoomRegistry, + core: LavinaCore, storage: Storage, termination: impl Future, ) -> Result<()> { @@ -55,13 +56,13 @@ async fn main_loop( let (stream, _) = result?; let stream = TokioIo::new(stream); let metrics = metrics.clone(); - let rooms = rooms.clone(); + let core = core.clone(); let storage = storage.clone(); tokio::task::spawn(async move { let registry = metrics.clone(); - let rooms = rooms.clone(); + let core = core.clone(); let storage = storage.clone(); - let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), rooms.clone(), storage.clone(), r))); + let server = http1::Builder::new().serve_connection(stream, service_fn(move |r| route(registry.clone(), core.clone(), storage.clone(), r))); if let Err(err) = server.await { tracing::error!("Error serving connection: {:?}", err); } @@ -75,13 +76,13 @@ async fn main_loop( async fn route( registry: MetricsRegistry, - rooms: RoomRegistry, + core: LavinaCore, storage: Storage, request: Request, ) -> HttpResult>> { let res = match (request.method(), request.uri().path()) { (&Method::GET, "/metrics") => endpoint_metrics(registry), - (&Method::GET, "/rooms") => endpoint_rooms(rooms).await, + (&Method::GET, "/rooms") => endpoint_rooms(core.rooms).await, (&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(), (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(), _ => not_found(), diff --git a/src/main.rs b/src/main.rs index 0d03a89..98c45f8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,10 +9,9 @@ use figment::{providers::Toml, Figment}; use prometheus::Registry as MetricsRegistry; use serde::Deserialize; -use lavina_core::player::PlayerRegistry; use lavina_core::prelude::*; use lavina_core::repo::Storage; -use lavina_core::room::RoomRegistry; +use lavina_core::LavinaCore; #[derive(Deserialize, Debug)] struct ServerConfig { @@ -49,27 +48,12 @@ async fn main() -> Result<()> { xmpp: xmpp_config, storage: storage_config, } = config; - let mut metrics = MetricsRegistry::new(); + let metrics = MetricsRegistry::new(); let storage = Storage::open(storage_config).await?; - let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; - let mut players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?; - let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), rooms.clone(), storage.clone()).await?; - let irc = projection_irc::launch( - irc_config, - players.clone(), - rooms.clone(), - metrics.clone(), - storage.clone(), - ) - .await?; - let xmpp = projection_xmpp::launch( - xmpp_config, - players.clone(), - rooms.clone(), - metrics.clone(), - storage.clone(), - ) - .await?; + let core = LavinaCore::new(metrics.clone(), storage.clone()).await?; + let telemetry_terminator = http::launch(telemetry_config, metrics.clone(), core.clone(), storage.clone()).await?; + let irc = projection_irc::launch(irc_config, core.clone(), metrics.clone(), storage.clone()).await?; + let xmpp = projection_xmpp::launch(xmpp_config, core.clone(), metrics.clone(), storage.clone()).await?; tracing::info!("Started"); sleep.await; @@ -78,10 +62,7 @@ async fn main() -> Result<()> { xmpp.terminate().await?; irc.terminate().await?; telemetry_terminator.terminate().await?; - players.shutdown_all().await?; - drop(players); - drop(rooms); - storage.close().await?; + core.shutdown().await?; tracing::info!("Shutdown complete"); Ok(()) } From 12d30ca5c204c38170e92f27f3ec1ad3cabd3dfb Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sun, 21 Apr 2024 21:00:44 +0000 Subject: [PATCH 21/37] irc: implement server-time capability for incoming messages (#52) Spec: https://ircv3.net/specs/extensions/server-time Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/52 --- Cargo.lock | 1 + Cargo.toml | 1 + crates/lavina-core/Cargo.toml | 2 +- crates/lavina-core/src/player.rs | 24 +++++--- crates/lavina-core/src/repo/mod.rs | 12 +++- crates/lavina-core/src/room.rs | 18 ++++-- crates/projection-irc/Cargo.toml | 1 + crates/projection-irc/src/cap.rs | 3 +- crates/projection-irc/src/lib.rs | 40 +++++++++++-- crates/projection-irc/tests/lib.rs | 85 +++++++++++++++++++++++++-- crates/projection-xmpp/src/updates.rs | 1 + crates/proto-irc/src/lib.rs | 15 ++++- crates/proto-irc/src/server.rs | 7 +++ 13 files changed, 183 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a3e6e02..a586d14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1263,6 +1263,7 @@ version = "0.0.2-dev" dependencies = [ "anyhow", "bitflags 2.5.0", + "chrono", "futures-util", "lavina-core", "nonempty", diff --git a/Cargo.toml b/Cargo.toml index 7234d24..0158e42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ base64 = "0.22.0" lavina-core = { path = "crates/lavina-core" } tracing-subscriber = "0.3.16" sasl = { path = "crates/sasl" } +chrono = "0.4.37" [package] name = "lavina" diff --git a/crates/lavina-core/Cargo.toml b/crates/lavina-core/Cargo.toml index 531e959..92bf798 100644 --- a/crates/lavina-core/Cargo.toml +++ b/crates/lavina-core/Cargo.toml @@ -10,4 +10,4 @@ serde.workspace = true tokio.workspace = true tracing.workspace = true prometheus.workspace = true -chrono = "0.4.37" +chrono.workspace = true diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 8693752..9925709 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -10,6 +10,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use chrono::{DateTime, Utc}; use prometheus::{IntGauge, Registry as MetricsRegistry}; use serde::Serialize; use tokio::sync::mpsc::{channel, Receiver, Sender}; @@ -57,7 +58,7 @@ pub struct PlayerConnection { } impl PlayerConnection { /// Handled in [Player::send_message]. - pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result<()> { + pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result { let (promise, deferred) = oneshot(); let cmd = ClientCommand::SendMessage { room_id, body, promise }; self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; @@ -163,7 +164,7 @@ pub enum ClientCommand { SendMessage { room_id: RoomId, body: Str, - promise: Promise<()>, + promise: Promise, }, ChangeTopic { room_id: RoomId, @@ -181,6 +182,11 @@ pub enum JoinResult { Banned, } +pub enum SendMessageResult { + Success(DateTime), + NoSuchRoom, +} + /// Player update event type which is sent to a player actor and from there to a connection handler. #[derive(Clone, Debug)] pub enum Updates { @@ -192,6 +198,7 @@ pub enum Updates { room_id: RoomId, author_id: PlayerId, body: Str, + created_at: DateTime, }, RoomJoined { room_id: RoomId, @@ -367,8 +374,8 @@ impl Player { let _ = promise.send(()); } ClientCommand::SendMessage { room_id, body, promise } => { - self.send_message(connection_id, room_id, body).await; - let _ = promise.send(()); + let result = self.send_message(connection_id, room_id, body).await; + let _ = promise.send(result); } ClientCommand::ChangeTopic { room_id, @@ -425,18 +432,21 @@ impl Player { self.broadcast_update(update, connection_id).await; } - async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) { + async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) -> SendMessageResult { let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); - return; + return SendMessageResult::NoSuchRoom; }; - room.send_message(&self.player_id, body.clone()).await; + let created_at = chrono::Utc::now(); + room.send_message(&self.player_id, body.clone(), created_at.clone()).await; let update = Updates::NewMessage { room_id, author_id: self.player_id.clone(), body, + created_at, }; self.broadcast_update(update, connection_id).await; + SendMessageResult::Success(created_at) } async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index e8e3854..645c764 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -4,6 +4,7 @@ use std::str::FromStr; use std::sync::Arc; use anyhow::anyhow; +use chrono::{DateTime, Utc}; use serde::Deserialize; use sqlx::sqlite::SqliteConnectOptions; use sqlx::{ConnectOptions, Connection, FromRow, Sqlite, SqliteConnection, Transaction}; @@ -80,7 +81,14 @@ impl Storage { Ok(id) } - pub async fn insert_message(&mut self, room_id: u32, id: u32, content: &str, author_id: &str) -> Result<()> { + pub async fn insert_message( + &mut self, + room_id: u32, + id: u32, + content: &str, + author_id: &str, + created_at: &DateTime, + ) -> Result<()> { let mut executor = self.conn.lock().await; let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;") .bind(author_id) @@ -98,7 +106,7 @@ impl Storage { .bind(id) .bind(content) .bind(author_id) - .bind(chrono::Utc::now().to_string()) + .bind(created_at.to_string()) .bind(room_id) .execute(&mut *executor) .await?; diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index a5e2dab..52ac7c4 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -2,6 +2,7 @@ use std::collections::HashSet; use std::{collections::HashMap, hash::Hash, sync::Arc}; +use chrono::{DateTime, Utc}; use prometheus::{IntGauge, Registry as MetricRegistry}; use serde::Serialize; use tokio::sync::RwLock as AsyncRwLock; @@ -163,9 +164,9 @@ impl RoomHandle { lock.broadcast_update(update, player_id).await; } - pub async fn send_message(&self, player_id: &PlayerId, body: Str) { + pub async fn send_message(&self, player_id: &PlayerId, body: Str, created_at: DateTime) { let mut lock = self.0.write().await; - let res = lock.send_message(player_id, body).await; + let res = lock.send_message(player_id, body, created_at).await; if let Err(err) = res { log::warn!("Failed to send message: {err:?}"); } @@ -208,14 +209,23 @@ struct Room { storage: Storage, } impl Room { - async fn send_message(&mut self, author_id: &PlayerId, body: Str) -> Result<()> { + async fn send_message(&mut self, author_id: &PlayerId, body: Str, created_at: DateTime) -> Result<()> { tracing::info!("Adding a message to room"); - self.storage.insert_message(self.storage_id, self.message_count, &body, &*author_id.as_inner()).await?; + self.storage + .insert_message( + self.storage_id, + self.message_count, + &body, + &*author_id.as_inner(), + &created_at, + ) + .await?; self.message_count += 1; let update = Updates::NewMessage { room_id: self.room_id.clone(), author_id: author_id.clone(), body, + created_at, }; self.broadcast_update(update, author_id).await; Ok(()) diff --git a/crates/projection-irc/Cargo.toml b/crates/projection-irc/Cargo.toml index 3135280..7275f69 100644 --- a/crates/projection-irc/Cargo.toml +++ b/crates/projection-irc/Cargo.toml @@ -12,6 +12,7 @@ tokio.workspace = true prometheus.workspace = true futures-util.workspace = true nonempty.workspace = true +chrono.workspace = true bitflags = "2.4.1" proto-irc = { path = "../proto-irc" } sasl = { path = "../sasl" } diff --git a/crates/projection-irc/src/cap.rs b/crates/projection-irc/src/cap.rs index af0e3ff..83f1e24 100644 --- a/crates/projection-irc/src/cap.rs +++ b/crates/projection-irc/src/cap.rs @@ -1,9 +1,10 @@ use bitflags::bitflags; bitflags! { - #[derive(Debug)] + #[derive(Debug, Clone, Copy)] pub struct Capabilities: u32 { const None = 0; const Sasl = 1 << 0; + const ServerTime = 1 << 1; } } diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 1513546..ce450c1 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use std::net::SocketAddr; use anyhow::{anyhow, Result}; +use chrono::SecondsFormat; use futures_util::future::join_all; use nonempty::nonempty; use nonempty::NonEmpty; @@ -24,7 +25,7 @@ use proto_irc::client::{client_message, ClientMessage}; use proto_irc::server::CapSubBody; use proto_irc::server::{AwayStatus, ServerMessage, ServerMessageBody}; use proto_irc::user::PrefixedNick; -use proto_irc::{Chan, Recipient}; +use proto_irc::{Chan, Recipient, Tag}; use sasl::AuthBody; mod cap; @@ -49,6 +50,7 @@ struct RegisteredUser { */ username: Str, realname: Str, + enabled_capabilities: Capabilities, } async fn handle_socket( @@ -136,7 +138,7 @@ impl RegistrationState { sender: Some(config.server_name.clone().into()), body: ServerMessageBody::Cap { target: self.future_nickname.clone().unwrap_or_else(|| "*".into()), - subcmd: CapSubBody::Ls("sasl=PLAIN".into()), + subcmd: CapSubBody::Ls("sasl=PLAIN server-time".into()), }, } .write_async(writer) @@ -156,16 +158,30 @@ impl RegistrationState { self.enabled_capabilities |= Capabilities::Sasl; } acked.push(cap); + } else if &*cap.name == "server-time" { + if cap.to_disable { + self.enabled_capabilities &= !Capabilities::ServerTime; + } else { + self.enabled_capabilities |= Capabilities::ServerTime; + } + acked.push(cap); } else { naked.push(cap); } } let mut ack_body = String::new(); - for cap in acked { - if cap.to_disable { + if let Some((first, tail)) = acked.split_first() { + if first.to_disable { ack_body.push('-'); } - ack_body += &*cap.name; + ack_body += &*first.name; + for cap in tail { + ack_body.push(' '); + if cap.to_disable { + ack_body.push('-'); + } + ack_body += &*cap.name; + } } ServerMessage { tags: vec![], @@ -195,6 +211,7 @@ impl RegistrationState { nickname: nickname.clone(), username, realname, + enabled_capabilities: self.enabled_capabilities, }; self.finalize_auth(candidate_user, writer, storage, config).await } @@ -208,6 +225,7 @@ impl RegistrationState { nickname: nickname.clone(), username: username.clone(), realname: realname.clone(), + enabled_capabilities: self.enabled_capabilities, }; self.finalize_auth(candidate_user, writer, storage, config).await } else { @@ -224,6 +242,7 @@ impl RegistrationState { nickname: nickname.clone(), username, realname, + enabled_capabilities: self.enabled_capabilities, }; self.finalize_auth(candidate_user, writer, storage, config).await } else { @@ -587,9 +606,18 @@ async fn handle_update( author_id, room_id, body, + created_at, } => { + let mut tags = vec![]; + if user.enabled_capabilities.contains(Capabilities::ServerTime) { + let tag = Tag { + key: "time".into(), + value: Some(created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()), + }; + tags.push(tag); + } ServerMessage { - tags: vec![], + tags, sender: Some(author_id.as_inner().clone()), body: ServerMessageBody::PrivateMessage { target: Recipient::Chan(Chan::Global(room_id.as_inner().clone())), diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index f2f4505..5a4eb7c 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -1,17 +1,20 @@ use std::io::ErrorKind; -use std::net::SocketAddr; use std::time::Duration; use anyhow::{anyhow, Result}; +use chrono::{DateTime, SecondsFormat}; use prometheus::Registry as MetricsRegistry; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; +use lavina_core::player::{JoinResult, PlayerId, SendMessageResult}; use lavina_core::repo::{Storage, StorageConfig}; +use lavina_core::room::RoomId; use lavina_core::LavinaCore; use projection_irc::APP_VERSION; use projection_irc::{launch, read_irc_message, RunningServer, ServerConfig}; + struct TestScope<'a> { reader: BufReader>, writer: WriteHalf<'a>, @@ -89,6 +92,11 @@ impl<'a> TestScope<'a> { Err(_) => Ok(()), } } + + async fn expect_cap_ls(&mut self) -> Result<()> { + self.expect(":testserver CAP * LS :sasl=PLAIN server-time").await?; + Ok(()) + } } struct TestServer { @@ -388,7 +396,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> { s.send("CAP LS 302").await?; s.send("NICK tester").await?; s.send("USER UserName 0 * :Real Name").await?; - s.expect(":testserver CAP * LS :sasl=PLAIN").await?; + s.expect_cap_ls().await?; s.send("CAP REQ :sasl").await?; s.expect(":testserver CAP tester ACK :sasl").await?; s.send("AUTHENTICATE PLAIN").await?; @@ -426,7 +434,7 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> { let mut s = TestScope::new(&mut stream); s.send("CAP LS 302").await?; - s.expect(":testserver CAP * LS :sasl=PLAIN").await?; + s.expect_cap_ls().await?; s.send("CAP REQ :sasl").await?; s.expect(":testserver CAP * ACK :sasl").await?; s.send("AUTHENTICATE PLAIN").await?; @@ -505,7 +513,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> { s.send("CAP LS 302").await?; s.send("NICK tester").await?; s.send("USER UserName 0 * :Real Name").await?; - s.expect(":testserver CAP * LS :sasl=PLAIN").await?; + s.expect_cap_ls().await?; s.send("CAP REQ :sasl").await?; s.expect(":testserver CAP tester ACK :sasl").await?; s.send("AUTHENTICATE SHA256").await?; @@ -558,3 +566,72 @@ async fn terminate_socket_scenario() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn server_time_capability() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + server.storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + + s.send("CAP LS 302").await?; + s.send("NICK tester").await?; + s.send("USER UserName 0 * :Real Name").await?; + s.expect_cap_ls().await?; + s.send("CAP REQ :sasl server-time").await?; + s.expect(":testserver CAP tester ACK :sasl server-time").await?; + s.send("AUTHENTICATE PLAIN").await?; + s.expect(":testserver AUTHENTICATE +").await?; + s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' + s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; + s.expect(":testserver 903 tester :SASL authentication successful").await?; + s.send("CAP END").await?; + s.expect_server_introduction("tester").await?; + s.expect_nothing().await?; + s.send("JOIN #test").await?; + s.expect(":tester JOIN #test").await?; + s.expect(":testserver 332 tester #test :New room").await?; + s.expect(":testserver 353 tester = #test :tester").await?; + s.expect(":testserver 366 tester #test :End of /NAMES list").await?; + + server.storage.create_user("some_guy").await?; + let mut conn = server.core.players.connect_to_player(&PlayerId::from("some_guy").unwrap()).await; + let res = conn.join_room(RoomId::from("test").unwrap()).await?; + let JoinResult::Success(_) = res else { + panic!("Failed to join room"); + }; + + s.expect(":some_guy JOIN #test").await?; + + let SendMessageResult::Success(res) = conn.send_message(RoomId::from("test").unwrap(), "Hello".into()).await? + else { + panic!("Failed to send message"); + }; + s.expect(&format!( + "@time={} :some_guy PRIVMSG #test :Hello", + res.to_rfc3339_opts(SecondsFormat::Millis, true) + )) + .await?; + + // formatting check + assert_eq!( + DateTime::parse_from_rfc3339(&"2024-01-01T10:00:32.123Z").unwrap().to_rfc3339_opts(SecondsFormat::Millis, true), + "2024-01-01T10:00:32.123Z" + ); + + s.send("QUIT :Leaving").await?; + s.expect(":testserver ERROR :Leaving the server").await?; + s.expect_eof().await?; + + stream.shutdown().await?; + + // wrap up + + server.server.terminate().await?; + Ok(()) +} diff --git a/crates/projection-xmpp/src/updates.rs b/crates/projection-xmpp/src/updates.rs index 0161b3f..fcc62b6 100644 --- a/crates/projection-xmpp/src/updates.rs +++ b/crates/projection-xmpp/src/updates.rs @@ -17,6 +17,7 @@ impl<'a> XmppConnection<'a> { room_id, author_id, body, + created_at: _, } => { Message::<()> { to: Some(Jid { diff --git a/crates/proto-irc/src/lib.rs b/crates/proto-irc/src/lib.rs index 54ff676..3c7ade2 100644 --- a/crates/proto-irc/src/lib.rs +++ b/crates/proto-irc/src/lib.rs @@ -18,8 +18,19 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; /// Single message tag value. #[derive(Clone, Debug, PartialEq, Eq)] pub struct Tag { - key: Str, - value: Option, + pub key: Str, + pub value: Option, +} + +impl Tag { + pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { + writer.write_all(self.key.as_bytes()).await?; + if let Some(value) = &self.value { + writer.write_all(b"=").await?; + writer.write_all(value.as_bytes()).await?; + } + Ok(()) + } } fn receiver(input: &str) -> IResult<&str, &str> { diff --git a/crates/proto-irc/src/server.rs b/crates/proto-irc/src/server.rs index c751e23..53cc05d 100644 --- a/crates/proto-irc/src/server.rs +++ b/crates/proto-irc/src/server.rs @@ -19,6 +19,13 @@ pub struct ServerMessage { impl ServerMessage { pub async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { + if !self.tags.is_empty() { + for tag in &self.tags { + writer.write_all(b"@").await?; + tag.write_async(writer).await?; + writer.write_all(b" ").await?; + } + } match &self.sender { Some(ref sender) => { writer.write_all(b":").await?; From 6c08d69f416d3e2feba4328bd2804291b4f98061 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 23 Apr 2024 00:41:54 +0200 Subject: [PATCH 22/37] sasl: remove unused code --- crates/sasl/src/lib.rs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/crates/sasl/src/lib.rs b/crates/sasl/src/lib.rs index e00d67b..75f69c5 100644 --- a/crates/sasl/src/lib.rs +++ b/crates/sasl/src/lib.rs @@ -79,10 +79,6 @@ mod test { fn test_fail_if_size_less_then_3() { let orig = b"login\x00pass"; let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody { - login: "login".to_string(), - password: "pass".to_string(), - }; let result = AuthBody::from_str(encoded.as_bytes()); assert!(result.is_err()); @@ -92,10 +88,6 @@ mod test { fn test_fail_if_size_greater_then_3() { let orig = b"first\x00login\x00pass\x00other"; let encoded = general_purpose::STANDARD.encode(orig); - let expected = AuthBody { - login: "login".to_string(), - password: "pass".to_string(), - }; let result = AuthBody::from_str(encoded.as_bytes()); assert!(result.is_err()); From d805061d5bb3fb59e04fbbcf96c4f68460c3de5a Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 23 Apr 2024 10:10:10 +0000 Subject: [PATCH 23/37] refactor auth logic into a common module (#54) Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/54 --- crates/lavina-core/src/auth.rs | 47 ++++++++++++++++++++++++++++++ crates/lavina-core/src/lib.rs | 1 + crates/lavina-core/src/repo/mod.rs | 4 +-- crates/projection-irc/src/lib.rs | 24 +++++---------- crates/projection-xmpp/src/lib.rs | 30 +++++++------------ src/http.rs | 24 ++++++++------- 6 files changed, 82 insertions(+), 48 deletions(-) create mode 100644 crates/lavina-core/src/auth.rs diff --git a/crates/lavina-core/src/auth.rs b/crates/lavina-core/src/auth.rs new file mode 100644 index 0000000..ba465db --- /dev/null +++ b/crates/lavina-core/src/auth.rs @@ -0,0 +1,47 @@ +use anyhow::Result; + +use crate::prelude::log; +use crate::repo::Storage; + +pub enum Verdict { + Authenticated, + UserNotFound, + InvalidPassword, +} + +pub enum UpdatePasswordResult { + PasswordUpdated, + UserNotFound, +} + +pub struct Authenticator<'a> { + storage: &'a Storage, +} +impl<'a> Authenticator<'a> { + pub fn new(storage: &'a Storage) -> Self { + Self { storage } + } + + pub async fn authenticate(&self, login: &str, provided_password: &str) -> Result { + let Some(stored_user) = self.storage.retrieve_user_by_name(login).await? else { + return Ok(Verdict::UserNotFound); + }; + let Some(expected_password) = stored_user.password else { + log::debug!("Password not defined for user '{}'", login); + return Ok(Verdict::InvalidPassword); + }; + if expected_password == provided_password { + return Ok(Verdict::Authenticated); + } + Ok(Verdict::InvalidPassword) + } + + pub async fn set_password(&self, login: &str, provided_password: &str) -> Result { + let Some(_) = self.storage.retrieve_user_by_name(login).await? else { + return Ok(UpdatePasswordResult::UserNotFound); + }; + self.storage.set_password(login, provided_password).await?; + log::info!("Password changed for player {login}"); + Ok(UpdatePasswordResult::PasswordUpdated) + } +} diff --git a/crates/lavina-core/src/lib.rs b/crates/lavina-core/src/lib.rs index ff52363..e611a01 100644 --- a/crates/lavina-core/src/lib.rs +++ b/crates/lavina-core/src/lib.rs @@ -6,6 +6,7 @@ use crate::player::PlayerRegistry; use crate::repo::Storage; use crate::room::RoomRegistry; +pub mod auth; pub mod player; pub mod prelude; pub mod repo; diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index 645c764..714b8cd 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -38,7 +38,7 @@ impl Storage { Ok(Storage { conn }) } - pub async fn retrieve_user_by_name(&mut self, name: &str) -> Result> { + pub async fn retrieve_user_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( "select u.id, u.name, c.password @@ -136,7 +136,7 @@ impl Storage { Ok(()) } - pub async fn set_password<'a>(&'a mut self, name: &'a str, pwd: &'a str) -> Result> { + pub async fn set_password<'a>(&'a self, name: &'a str, pwd: &'a str) -> Result> { async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result> { let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;") .bind(name) diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index ce450c1..278d456 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -14,6 +14,7 @@ use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::channel; +use lavina_core::auth::{Authenticator, Verdict}; use lavina_core::player::*; use lavina_core::prelude::*; use lavina_core::repo::Storage; @@ -405,24 +406,13 @@ fn sasl_fail_message(sender: Str, nick: Str, text: Str) -> ServerMessage { } async fn auth_user(storage: &mut Storage, login: &str, plain_password: &str) -> Result<()> { - let stored_user = storage.retrieve_user_by_name(login).await?; - - let stored_user = match stored_user { - Some(u) => u, - None => { - log::info!("User '{}' not found", login); - return Err(anyhow!("no user found")); - } - }; - let Some(expected_password) = stored_user.password else { - log::info!("Password not defined for user '{}'", login); - return Err(anyhow!("password is not defined")); - }; - if expected_password != plain_password { - log::info!("Incorrect password supplied for user '{}'", login); - return Err(anyhow!("passwords do not match")); + let verdict = Authenticator::new(storage).authenticate(login, plain_password).await?; + // TODO properly map these onto protocol messages + match verdict { + Verdict::Authenticated => Ok(()), + Verdict::UserNotFound => Err(anyhow!("no user found")), + Verdict::InvalidPassword => Err(anyhow!("incorrect credentials")), } - Ok(()) } async fn handle_registered_socket<'a>( diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 6539254..01f0171 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -9,6 +9,7 @@ use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; +use anyhow::anyhow; use futures_util::future::join_all; use prometheus::Registry as MetricsRegistry; use quick_xml::events::{BytesDecl, Event}; @@ -21,6 +22,7 @@ use tokio::sync::mpsc::channel; use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; +use lavina_core::auth::{Authenticator, Verdict}; use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry}; use lavina_core::prelude::*; use lavina_core::repo::Storage; @@ -300,28 +302,18 @@ async fn socket_auth( match AuthBody::from_str(&auth.body) { Ok(logopass) => { let name = &logopass.login; - let stored_user = storage.retrieve_user_by_name(name).await?; - - let stored_user = match stored_user { - Some(u) => u, - None => { - log::info!("User '{}' not found", name); - return Err(fail("no user found")); - } - }; + let verdict = Authenticator::new(storage).authenticate(name, &logopass.password).await?; // TODO return proper XML errors to the client - - if stored_user.password.is_none() { - log::info!("Password not defined for user '{}'", name); - return Err(fail("password is not defined")); + match verdict { + Verdict::Authenticated => {} + Verdict::UserNotFound => { + return Err(anyhow!("no user found")); + } + Verdict::InvalidPassword => { + return Err(anyhow!("incorrect credentials")); + } } - if stored_user.password.as_deref() != Some(&logopass.password) { - log::info!("Incorrect password supplied for user '{}'", name); - return Err(fail("passwords do not match")); - } - let name: Str = name.as_str().into(); - Ok(Authenticated { player_id: PlayerId::from(name.clone())?, xmpp_name: Name(name.clone()), diff --git a/src/http.rs b/src/http.rs index 302bf5f..4bf3ffe 100644 --- a/src/http.rs +++ b/src/http.rs @@ -12,6 +12,7 @@ use prometheus::{Encoder, Registry as MetricsRegistry, TextEncoder}; use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; +use lavina_core::auth::{Authenticator, UpdatePasswordResult}; use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::RoomRegistry; @@ -141,17 +142,20 @@ async fn endpoint_set_password( *response.status_mut() = StatusCode::BAD_REQUEST; return Ok(response); }; - let Some(_) = storage.set_password(&res.player_name, &res.password).await? else { - let payload = ErrorResponse { - code: errors::PLAYER_NOT_FOUND, - message: "No such player exists", + let verdict = Authenticator::new(&storage).set_password(&res.player_name, &res.password).await?; + match verdict { + UpdatePasswordResult::PasswordUpdated => {} + UpdatePasswordResult::UserNotFound => { + let payload = ErrorResponse { + code: errors::PLAYER_NOT_FOUND, + message: "No such player exists", + } + .to_body(); + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY; + return Ok(response); } - .to_body(); - let mut response = Response::new(payload); - *response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY; - return Ok(response); - }; - log::info!("Password changed for player {}", res.player_name); + } let mut response = Response::new(Full::::default()); *response.status_mut() = StatusCode::NO_CONTENT; Ok(response) From 799da8366c5738d32084d7f043478d386008ef3c Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 23 Apr 2024 16:26:40 +0000 Subject: [PATCH 24/37] basic dialog implementation with irc and xmpp support (#53) Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/53 --- Cargo.lock | 5 + crates/lavina-core/Cargo.toml | 2 +- crates/lavina-core/migrations/3_dialogs.sql | 17 +++ crates/lavina-core/src/dialog.rs | 150 ++++++++++++++++++++ crates/lavina-core/src/lib.rs | 13 +- crates/lavina-core/src/player.rs | 71 ++++++++- crates/lavina-core/src/repo/dialog.rs | 68 +++++++++ crates/lavina-core/src/repo/mod.rs | 1 + crates/projection-irc/src/lib.rs | 30 ++++ crates/projection-irc/tests/lib.rs | 60 ++++++++ crates/projection-xmpp/src/message.rs | 4 + crates/projection-xmpp/src/updates.rs | 28 ++++ 12 files changed, 443 insertions(+), 6 deletions(-) create mode 100644 crates/lavina-core/migrations/3_dialogs.sql create mode 100644 crates/lavina-core/src/dialog.rs create mode 100644 crates/lavina-core/src/repo/dialog.rs diff --git a/Cargo.lock b/Cargo.lock index a586d14..a909a9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1775,6 +1775,7 @@ dependencies = [ "atoi", "byteorder", "bytes", + "chrono", "crc", "crossbeam-queue", "either", @@ -1833,6 +1834,7 @@ dependencies = [ "sha2", "sqlx-core", "sqlx-mysql", + "sqlx-postgres", "sqlx-sqlite", "syn 1.0.109", "tempfile", @@ -1850,6 +1852,7 @@ dependencies = [ "bitflags 2.5.0", "byteorder", "bytes", + "chrono", "crc", "digest", "dotenvy", @@ -1891,6 +1894,7 @@ dependencies = [ "base64 0.21.7", "bitflags 2.5.0", "byteorder", + "chrono", "crc", "dotenvy", "etcetera", @@ -1926,6 +1930,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" dependencies = [ "atoi", + "chrono", "flume", "futures-channel", "futures-core", diff --git a/crates/lavina-core/Cargo.toml b/crates/lavina-core/Cargo.toml index 92bf798..c49f83d 100644 --- a/crates/lavina-core/Cargo.toml +++ b/crates/lavina-core/Cargo.toml @@ -5,7 +5,7 @@ version.workspace = true [dependencies] anyhow.workspace = true -sqlx = { version = "0.7.4", features = ["sqlite", "migrate"] } +sqlx = { version = "0.7.4", features = ["sqlite", "migrate", "chrono"] } serde.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/crates/lavina-core/migrations/3_dialogs.sql b/crates/lavina-core/migrations/3_dialogs.sql new file mode 100644 index 0000000..599b306 --- /dev/null +++ b/crates/lavina-core/migrations/3_dialogs.sql @@ -0,0 +1,17 @@ +create table dialogs( + id integer primary key autoincrement not null, + participant_1 integer not null, + participant_2 integer not null, + created_at timestamp not null, + message_count integer not null default 0, + unique (participant_1, participant_2) +); + +create table dialog_messages( + dialog_id integer not null, + id integer not null, -- unique per dialog, sequential in one dialog + author_id integer not null, + content string not null, + created_at timestamp not null, + primary key (dialog_id, id) +); diff --git a/crates/lavina-core/src/dialog.rs b/crates/lavina-core/src/dialog.rs new file mode 100644 index 0000000..66fe8b5 --- /dev/null +++ b/crates/lavina-core/src/dialog.rs @@ -0,0 +1,150 @@ +//! Domain of dialogs – conversations between two participants. +//! +//! Dialogs are different from rooms in that they are always between two participants. +//! There are no admins or other roles in dialogs, both participants have equal rights. + +use std::collections::HashMap; +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use tokio::sync::RwLock as AsyncRwLock; + +use crate::player::{PlayerId, PlayerRegistry, Updates}; +use crate::prelude::*; +use crate::repo::Storage; + +/// Id of a conversation between two players. +/// +/// Dialogs are identified by the pair of participants' ids. The order of ids does not matter. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DialogId(PlayerId, PlayerId); +impl DialogId { + pub fn new(a: PlayerId, b: PlayerId) -> DialogId { + if a.as_inner() < b.as_inner() { + DialogId(a, b) + } else { + DialogId(b, a) + } + } + + pub fn as_inner(&self) -> (&PlayerId, &PlayerId) { + (&self.0, &self.1) + } + + pub fn into_inner(self) -> (PlayerId, PlayerId) { + (self.0, self.1) + } +} + +struct Dialog { + storage_id: u32, + player_storage_id_1: u32, + player_storage_id_2: u32, + message_count: u32, +} + +struct DialogRegistryInner { + dialogs: HashMap>, + players: Option, + storage: Storage, +} + +#[derive(Clone)] +pub struct DialogRegistry(Arc>); + +impl DialogRegistry { + pub async fn send_message( + &self, + from: PlayerId, + to: PlayerId, + body: Str, + created_at: &DateTime, + ) -> Result<()> { + let mut guard = self.0.read().await; + let id = DialogId::new(from.clone(), to.clone()); + let dialog = guard.dialogs.get(&id); + if let Some(d) = dialog { + let mut d = d.write().await; + guard.storage.increment_dialog_message_count(d.storage_id).await?; + d.message_count += 1; + } else { + drop(guard); + let mut guard2 = self.0.write().await; + // double check in case concurrent access has loaded this dialog + if let Some(d) = guard2.dialogs.get(&id) { + let mut d = d.write().await; + guard2.storage.increment_dialog_message_count(d.storage_id).await?; + d.message_count += 1; + } else { + let (p1, p2) = id.as_inner(); + tracing::info!("Dialog {id:?} not found locally, trying to load from storage"); + let stored_dialog = match guard2.storage.retrieve_dialog(p1.as_inner(), p2.as_inner()).await? { + Some(t) => t, + None => { + tracing::info!("Dialog {id:?} does not exist, creating a new one in storage"); + guard2.storage.initialize_dialog(p1.as_inner(), p2.as_inner(), created_at).await? + } + }; + tracing::info!("Dialog {id:?} loaded"); + guard2.storage.increment_dialog_message_count(stored_dialog.id).await?; + let dialog = Dialog { + storage_id: stored_dialog.id, + player_storage_id_1: stored_dialog.participant_1, + player_storage_id_2: stored_dialog.participant_2, + message_count: stored_dialog.message_count + 1, + }; + guard2.dialogs.insert(id.clone(), AsyncRwLock::new(dialog)); + } + guard = guard2.downgrade(); + } + // TODO send message to the other player and persist it + let Some(players) = &guard.players else { + tracing::error!("No player registry present"); + return Ok(()); + }; + let Some(player) = players.get_player(&to).await else { + tracing::debug!("Player {to:?} not active, not sending message"); + return Ok(()); + }; + let update = Updates::NewDialogMessage { + sender: from.clone(), + receiver: to.clone(), + body: body.clone(), + created_at: created_at.clone(), + }; + player.update(update).await; + return Ok(()); + } +} + +impl DialogRegistry { + pub fn new(storage: Storage) -> DialogRegistry { + DialogRegistry(Arc::new(AsyncRwLock::new(DialogRegistryInner { + dialogs: HashMap::new(), + players: None, + storage, + }))) + } + + pub async fn set_players(&self, players: PlayerRegistry) { + let mut guard = self.0.write().await; + guard.players = Some(players); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dialog_id_new() { + let a = PlayerId::from("a").unwrap(); + let b = PlayerId::from("b").unwrap(); + let id1 = DialogId::new(a.clone(), b.clone()); + let id2 = DialogId::new(a.clone(), b.clone()); + // Dialog ids are invariant with respect to the order of participants + assert_eq!(id1, id2); + assert_eq!(id1.as_inner(), (&a, &b)); + assert_eq!(id2.as_inner(), (&a, &b)); + } +} diff --git a/crates/lavina-core/src/lib.rs b/crates/lavina-core/src/lib.rs index e611a01..1128c61 100644 --- a/crates/lavina-core/src/lib.rs +++ b/crates/lavina-core/src/lib.rs @@ -2,11 +2,13 @@ use anyhow::Result; use prometheus::Registry as MetricsRegistry; +use crate::dialog::DialogRegistry; use crate::player::PlayerRegistry; use crate::repo::Storage; use crate::room::RoomRegistry; pub mod auth; +pub mod dialog; pub mod player; pub mod prelude; pub mod repo; @@ -19,14 +21,21 @@ mod table; pub struct LavinaCore { pub players: PlayerRegistry, pub rooms: RoomRegistry, + pub dialogs: DialogRegistry, } impl LavinaCore { pub async fn new(mut metrics: MetricsRegistry, storage: Storage) -> Result { // TODO shutdown all services in reverse order on error let rooms = RoomRegistry::new(&mut metrics, storage.clone())?; - let players = PlayerRegistry::empty(rooms.clone(), storage.clone(), &mut metrics)?; - Ok(LavinaCore { players, rooms }) + let dialogs = DialogRegistry::new(storage.clone()); + let players = PlayerRegistry::empty(rooms.clone(), dialogs.clone(), storage.clone(), &mut metrics)?; + dialogs.set_players(players.clone()).await; + Ok(LavinaCore { + players, + rooms, + dialogs, + }) } pub async fn shutdown(mut self) -> Result<()> { diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 9925709..3a58812 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -16,6 +16,7 @@ use serde::Serialize; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::RwLock; +use crate::dialog::DialogRegistry; use crate::prelude::*; use crate::repo::Storage; use crate::room::{RoomHandle, RoomId, RoomInfo, RoomRegistry}; @@ -104,6 +105,18 @@ impl PlayerConnection { self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; Ok(deferred.await?) } + + /// Handler in [Player::send_dialog_message]. + pub async fn send_dialog_message(&self, recipient: PlayerId, body: Str) -> Result<()> { + let (promise, deferred) = oneshot(); + let cmd = ClientCommand::SendDialogMessage { + recipient, + body, + promise, + }; + self.player_handle.send(ActorCommand::ClientCommand(cmd, self.connection_id.clone())).await; + Ok(deferred.await?) + } } /// Handle to a player actor. @@ -174,6 +187,11 @@ pub enum ClientCommand { GetRooms { promise: Promise>, }, + SendDialogMessage { + recipient: PlayerId, + body: Str, + promise: Promise<()>, + }, } pub enum JoinResult { @@ -210,6 +228,12 @@ pub enum Updates { }, /// The player was banned from the room and left it immediately. BannedFrom(RoomId), + NewDialogMessage { + sender: PlayerId, + receiver: PlayerId, + body: Str, + created_at: DateTime, + }, } /// Handle to a player registry — a shared data structure containing information about players. @@ -218,6 +242,7 @@ pub struct PlayerRegistry(Arc>); impl PlayerRegistry { pub fn empty( room_registry: RoomRegistry, + dialogs: DialogRegistry, storage: Storage, metrics: &mut MetricsRegistry, ) -> Result { @@ -225,6 +250,7 @@ impl PlayerRegistry { metrics.register(Box::new(metric_active_players.clone()))?; let inner = PlayerRegistryInner { room_registry, + dialogs, storage, players: HashMap::new(), metric_active_players, @@ -232,12 +258,23 @@ impl PlayerRegistry { Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) } + pub async fn get_player(&self, id: &PlayerId) -> Option { + let inner = self.0.read().await; + inner.players.get(id).map(|(handle, _)| handle.clone()) + } + pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { let mut inner = self.0.write().await; if let Some((handle, _)) = inner.players.get(id) { handle.clone() } else { - let (handle, fiber) = Player::launch(id.clone(), inner.room_registry.clone(), inner.storage.clone()).await; + let (handle, fiber) = Player::launch( + id.clone(), + inner.room_registry.clone(), + inner.dialogs.clone(), + inner.storage.clone(), + ) + .await; inner.players.insert(id.clone(), (handle.clone(), fiber)); inner.metric_active_players.inc(); handle @@ -265,6 +302,7 @@ impl PlayerRegistry { /// The player registry state representation. struct PlayerRegistryInner { room_registry: RoomRegistry, + dialogs: DialogRegistry, storage: Storage, /// Active player actors. players: HashMap)>, @@ -281,10 +319,16 @@ struct Player { rx: Receiver, handle: PlayerHandle, rooms: RoomRegistry, + dialogs: DialogRegistry, storage: Storage, } impl Player { - async fn launch(player_id: PlayerId, rooms: RoomRegistry, storage: Storage) -> (PlayerHandle, JoinHandle) { + async fn launch( + player_id: PlayerId, + rooms: RoomRegistry, + dialogs: DialogRegistry, + storage: Storage, + ) -> (PlayerHandle, JoinHandle) { let (tx, rx) = channel(32); let handle = PlayerHandle { tx }; let handle_clone = handle.clone(); @@ -301,6 +345,7 @@ impl Player { rx, handle, rooms, + dialogs, storage, }; let fiber = tokio::task::spawn(player.main_loop()); @@ -340,7 +385,7 @@ impl Player { /// Handle an incoming update by changing the internal state and broadcasting it to all connections if necessary. async fn handle_update(&mut self, update: Updates) { - log::info!( + log::debug!( "Player received an update, broadcasting to {} connections", self.connections.len() ); @@ -389,6 +434,14 @@ impl Player { let result = self.get_rooms().await; let _ = promise.send(result); } + ClientCommand::SendDialogMessage { + recipient, + body, + promise, + } => { + self.send_dialog_message(connection_id, recipient, body).await; + let _ = promise.send(()); + } } } @@ -467,6 +520,18 @@ impl Player { response } + async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) { + let created_at = chrono::Utc::now(); + self.dialogs.send_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at).await.unwrap(); + let update = Updates::NewDialogMessage { + sender: self.player_id.clone(), + receiver: recipient.clone(), + body, + created_at, + }; + self.broadcast_update(update, connection_id).await; + } + /// Broadcasts an update to all connections except the one with the given id. /// /// This is called after handling a client command. diff --git a/crates/lavina-core/src/repo/dialog.rs b/crates/lavina-core/src/repo/dialog.rs new file mode 100644 index 0000000..cbe3161 --- /dev/null +++ b/crates/lavina-core/src/repo/dialog.rs @@ -0,0 +1,68 @@ +use anyhow::Result; +use chrono::{DateTime, Utc}; +use sqlx::FromRow; + +use crate::repo::Storage; + +impl Storage { + pub async fn retrieve_dialog(&self, participant_1: &str, participant_2: &str) -> Result> { + let mut executor = self.conn.lock().await; + let res = sqlx::query_as( + "select r.id, r.participant_1, r.participant_2, r.message_count + from dialogs r join users u1 on r.participant_1 = u1.id join users u2 on r.participant_2 = u2.id + where u1.name = ? and u2.name = ?;", + ) + .bind(participant_1) + .bind(participant_2) + .fetch_optional(&mut *executor) + .await?; + + Ok(res) + } + + pub async fn increment_dialog_message_count(&self, storage_id: u32) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "update rooms set message_count = message_count + 1 + where id = ?;", + ) + .bind(storage_id) + .execute(&mut *executor) + .await?; + + Ok(()) + } + + pub async fn initialize_dialog( + &self, + participant_1: &str, + participant_2: &str, + created_at: &DateTime, + ) -> Result { + let mut executor = self.conn.lock().await; + let res: StoredDialog = sqlx::query_as( + "insert into dialogs(participant_1, participant_2, created_at) + values ( + (select id from users where name = ?), + (select id from users where name = ?), + ? + ) + returning id, participant_1, participant_2, message_count;", + ) + .bind(participant_1) + .bind(participant_2) + .bind(&created_at) + .fetch_one(&mut *executor) + .await?; + + Ok(res) + } +} + +#[derive(FromRow)] +pub struct StoredDialog { + pub id: u32, + pub participant_1: u32, + pub participant_2: u32, + pub message_count: u32, +} diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index 714b8cd..dfa93c6 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -12,6 +12,7 @@ use tokio::sync::Mutex; use crate::prelude::*; +mod dialog; mod room; mod user; diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 278d456..342682a 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -643,6 +643,32 @@ async fn handle_update( .await?; writer.flush().await? } + Updates::NewDialogMessage { + sender, + receiver, + body, + created_at, + } => { + let mut tags = vec![]; + if user.enabled_capabilities.contains(Capabilities::ServerTime) { + let tag = Tag { + key: "time".into(), + value: Some(created_at.to_rfc3339_opts(SecondsFormat::Millis, true).into()), + }; + tags.push(tag); + } + ServerMessage { + tags, + sender: Some(sender.as_inner().clone()), + body: ServerMessageBody::PrivateMessage { + target: Recipient::Nick(receiver.as_inner().clone()), + body: body.clone(), + }, + } + .write_async(writer) + .await?; + writer.flush().await? + } } Ok(()) } @@ -689,6 +715,10 @@ async fn handle_incoming_message( let room_id = RoomId::from(chan)?; user_handle.send_message(room_id, body).await?; } + Recipient::Nick(nick) => { + let receiver = PlayerId::from(nick)?; + user_handle.send_dialog_message(receiver, body).await?; + } _ => log::warn!("Unsupported target type"), }, ClientMessage::Topic { chan, topic } => { diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 5a4eb7c..6a90c46 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -635,3 +635,63 @@ async fn server_time_capability() -> Result<()> { server.server.terminate().await?; Ok(()) } + +#[tokio::test] +async fn scenario_two_players_dialog() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester1").await?; + server.storage.set_password("tester1", "password").await?; + server.storage.create_user("tester2").await?; + server.storage.set_password("tester2", "password").await?; + + let mut stream1 = TcpStream::connect(server.server.addr).await?; + let mut s1 = TestScope::new(&mut stream1); + + let mut stream2 = TcpStream::connect(server.server.addr).await?; + let mut s2 = TestScope::new(&mut stream2); + + s1.send("CAP LS 302").await?; + s1.send("NICK tester1").await?; + s1.send("USER UserName 0 * :Real Name").await?; + s1.expect_cap_ls().await?; + s1.send("CAP REQ :sasl").await?; + s1.expect(":testserver CAP tester1 ACK :sasl").await?; + s1.send("AUTHENTICATE PLAIN").await?; + s1.expect(":testserver AUTHENTICATE +").await?; + s1.send("AUTHENTICATE dGVzdGVyMQB0ZXN0ZXIxAHBhc3N3b3Jk").await?; // base64-encoded 'tester1\x00tester1\x00password' + s1.expect(":testserver 900 tester1 tester1 tester1 :You are now logged in as tester1").await?; + s1.expect(":testserver 903 tester1 :SASL authentication successful").await?; + s1.send("CAP END").await?; + s1.expect_server_introduction("tester1").await?; + s1.expect_nothing().await?; + + s2.send("CAP LS 302").await?; + s2.send("NICK tester2").await?; + s2.send("USER UserName 0 * :Real Name").await?; + s2.expect_cap_ls().await?; + s2.send("CAP REQ :sasl").await?; + s2.expect(":testserver CAP tester2 ACK :sasl").await?; + s2.send("AUTHENTICATE PLAIN").await?; + s2.expect(":testserver AUTHENTICATE +").await?; + s2.send("AUTHENTICATE dGVzdGVyMgB0ZXN0ZXIyAHBhc3N3b3Jk").await?; // base64-encoded 'tester2\x00tester2\x00password' + s2.expect(":testserver 900 tester2 tester2 tester2 :You are now logged in as tester2").await?; + s2.expect(":testserver 903 tester2 :SASL authentication successful").await?; + s2.send("CAP END").await?; + s2.expect_server_introduction("tester2").await?; + s2.expect_nothing().await?; + + s1.send("PRIVMSG tester2 :Henlo! How are you?").await?; + s1.expect_nothing().await?; + s2.expect(":tester1 PRIVMSG tester2 :Henlo! How are you?").await?; + s2.expect_nothing().await?; + + s2.send("PRIVMSG tester1 good").await?; + s2.expect_nothing().await?; + s1.expect(":tester2 PRIVMSG tester1 :good").await?; + s1.expect_nothing().await?; + + Ok(()) +} diff --git a/crates/projection-xmpp/src/message.rs b/crates/projection-xmpp/src/message.rs index a737b2b..15a3e0d 100644 --- a/crates/projection-xmpp/src/message.rs +++ b/crates/projection-xmpp/src/message.rs @@ -1,5 +1,6 @@ //! Handling of all client2server message stanzas +use lavina_core::player::PlayerId; use quick_xml::events::Event; use lavina_core::prelude::*; @@ -40,6 +41,9 @@ impl<'a> XmppConnection<'a> { } .serialize(output); Ok(()) + } else if server.0.as_ref() == &*self.hostname && m.r#type == MessageType::Chat { + self.user_handle.send_dialog_message(PlayerId::from(name.0.clone())?, m.body.clone()).await?; + Ok(()) } else { todo!() } diff --git a/crates/projection-xmpp/src/updates.rs b/crates/projection-xmpp/src/updates.rs index fcc62b6..d659467 100644 --- a/crates/projection-xmpp/src/updates.rs +++ b/crates/projection-xmpp/src/updates.rs @@ -39,6 +39,34 @@ impl<'a> XmppConnection<'a> { } .serialize(output); } + Updates::NewDialogMessage { + sender, + receiver, + body, + created_at: _, + } => { + if receiver == self.user.player_id { + Message::<()> { + to: Some(Jid { + name: Some(self.user.xmpp_name.clone()), + server: Server(self.hostname.clone()), + resource: Some(self.user.xmpp_resource.clone()), + }), + from: Some(Jid { + name: Some(Name(sender.as_inner().clone())), + server: Server(self.hostname.clone()), + resource: Some(Resource(sender.into_inner())), + }), + id: None, + r#type: MessageType::Chat, + lang: None, + subject: None, + body: body.into(), + custom: vec![], + } + .serialize(output); + } + } _ => {} } Ok(()) From d305f5bf776d902997dfed3656c5dbe430940736 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 23 Apr 2024 16:31:00 +0000 Subject: [PATCH 25/37] argon2-based password hashing (#55) Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/55 --- Cargo.lock | 34 ++++++++++++++++++ crates/lavina-core/Cargo.toml | 2 ++ .../migrations/4_new_challenges.sql | 4 +++ crates/lavina-core/src/auth.rs | 35 ++++++++++++++----- crates/lavina-core/src/player.rs | 28 +++++++++------ crates/lavina-core/src/repo/auth.rs | 19 ++++++++++ crates/lavina-core/src/repo/mod.rs | 5 ++- crates/projection-irc/tests/lib.rs | 25 ++++++------- crates/projection-xmpp/tests/lib.rs | 7 ++-- 9 files changed, 123 insertions(+), 36 deletions(-) create mode 100644 crates/lavina-core/migrations/4_new_challenges.sql create mode 100644 crates/lavina-core/src/repo/auth.rs diff --git a/Cargo.lock b/Cargo.lock index a909a9a..9658ff6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,18 @@ version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" +[[package]] +name = "argon2" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +dependencies = [ + "base64ct", + "blake2", + "cpufeatures", + "password-hash", +] + [[package]] name = "assert_matches" version = "1.5.0" @@ -192,6 +204,15 @@ dependencies = [ "serde", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -882,8 +903,10 @@ name = "lavina-core" version = "0.0.2-dev" dependencies = [ "anyhow", + "argon2", "chrono", "prometheus", + "rand_core", "serde", "sqlx", "tokio", @@ -1126,6 +1149,17 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "password-hash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.14" diff --git a/crates/lavina-core/Cargo.toml b/crates/lavina-core/Cargo.toml index c49f83d..ab26daf 100644 --- a/crates/lavina-core/Cargo.toml +++ b/crates/lavina-core/Cargo.toml @@ -11,3 +11,5 @@ tokio.workspace = true tracing.workspace = true prometheus.workspace = true chrono.workspace = true +argon2 = { version = "0.5.3" } +rand_core = { version = "0.6.4", features = ["getrandom"] } diff --git a/crates/lavina-core/migrations/4_new_challenges.sql b/crates/lavina-core/migrations/4_new_challenges.sql new file mode 100644 index 0000000..9017511 --- /dev/null +++ b/crates/lavina-core/migrations/4_new_challenges.sql @@ -0,0 +1,4 @@ +create table challenges_argon2_password( + user_id integer primary key not null, + hash string not null +); diff --git a/crates/lavina-core/src/auth.rs b/crates/lavina-core/src/auth.rs index ba465db..ccf962d 100644 --- a/crates/lavina-core/src/auth.rs +++ b/crates/lavina-core/src/auth.rs @@ -1,4 +1,7 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; +use argon2::password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString}; +use argon2::Argon2; +use rand_core::OsRng; use crate::prelude::log; use crate::repo::Storage; @@ -26,21 +29,35 @@ impl<'a> Authenticator<'a> { let Some(stored_user) = self.storage.retrieve_user_by_name(login).await? else { return Ok(Verdict::UserNotFound); }; - let Some(expected_password) = stored_user.password else { - log::debug!("Password not defined for user '{}'", login); - return Ok(Verdict::InvalidPassword); - }; - if expected_password == provided_password { - return Ok(Verdict::Authenticated); + if let Some(argon2_hash) = stored_user.argon2_hash { + let argon2 = Argon2::default(); + let password_hash = + PasswordHash::new(&argon2_hash).map_err(|e| anyhow!("Failed to parse password hash: {e:?}"))?; + let password_verifier = argon2.verify_password(provided_password.as_bytes(), &password_hash); + if password_verifier.is_ok() { + return Ok(Verdict::Authenticated); + } + } + if let Some(expected_password) = stored_user.password { + if expected_password == provided_password { + return Ok(Verdict::Authenticated); + } } Ok(Verdict::InvalidPassword) } pub async fn set_password(&self, login: &str, provided_password: &str) -> Result { - let Some(_) = self.storage.retrieve_user_by_name(login).await? else { + let Some(u) = self.storage.retrieve_user_by_name(login).await? else { return Ok(UpdatePasswordResult::UserNotFound); }; - self.storage.set_password(login, provided_password).await?; + + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let password_hash = argon2 + .hash_password(provided_password.as_bytes(), &salt) + .map_err(|e| anyhow!("Failed to hash password: {e:?}"))?; + + self.storage.set_argon2_challenge(u.id, password_hash.to_string().as_str()).await?; log::info!("Password changed for player {login}"); Ok(UpdatePasswordResult::PasswordUpdated) } diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 3a58812..4d6f6cb 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -264,20 +264,26 @@ impl PlayerRegistry { } pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { - let mut inner = self.0.write().await; + let inner = self.0.read().await; if let Some((handle, _)) = inner.players.get(id) { handle.clone() } else { - let (handle, fiber) = Player::launch( - id.clone(), - inner.room_registry.clone(), - inner.dialogs.clone(), - inner.storage.clone(), - ) - .await; - inner.players.insert(id.clone(), (handle.clone(), fiber)); - inner.metric_active_players.inc(); - handle + drop(inner); + let mut inner = self.0.write().await; + if let Some((handle, _)) = inner.players.get(id) { + handle.clone() + } else { + let (handle, fiber) = Player::launch( + id.clone(), + inner.room_registry.clone(), + inner.dialogs.clone(), + inner.storage.clone(), + ) + .await; + inner.players.insert(id.clone(), (handle.clone(), fiber)); + inner.metric_active_players.inc(); + handle + } } } diff --git a/crates/lavina-core/src/repo/auth.rs b/crates/lavina-core/src/repo/auth.rs new file mode 100644 index 0000000..f7c0d69 --- /dev/null +++ b/crates/lavina-core/src/repo/auth.rs @@ -0,0 +1,19 @@ +use anyhow::Result; + +use crate::repo::Storage; + +impl Storage { + pub async fn set_argon2_challenge(&self, user_id: u32, hash: &str) -> Result<()> { + let mut executor = self.conn.lock().await; + sqlx::query( + "insert into challenges_argon2_password(user_id, hash) + values (?, ?) + on conflict(user_id) do update set hash = excluded.hash;", + ) + .bind(user_id) + .bind(hash) + .execute(&mut *executor) + .await?; + Ok(()) + } +} diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index dfa93c6..9c7aff6 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -12,6 +12,7 @@ use tokio::sync::Mutex; use crate::prelude::*; +mod auth; mod dialog; mod room; mod user; @@ -42,8 +43,9 @@ impl Storage { pub async fn retrieve_user_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( - "select u.id, u.name, c.password + "select u.id, u.name, c.password, a.hash as argon2_hash from users u left join challenges_plain_password c on u.id = c.user_id + left join challenges_argon2_password a on u.id = a.user_id where u.name = ?;", ) .bind(name) @@ -175,6 +177,7 @@ pub struct StoredUser { pub id: u32, pub name: String, pub password: Option, + pub argon2_hash: Option>, } #[derive(FromRow)] diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 6a90c46..2de4f9e 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -8,6 +8,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::net::tcp::{ReadHalf, WriteHalf}; use tokio::net::TcpStream; +use lavina_core::auth::Authenticator; use lavina_core::player::{JoinResult, PlayerId, SendMessageResult}; use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::room::RoomId; @@ -27,7 +28,7 @@ impl<'a> TestScope<'a> { let (reader, writer) = stream.split(); let reader = BufReader::new(reader); let buffer = vec![]; - let timeout = Duration::from_millis(100); + let timeout = Duration::from_millis(1000); TestScope { reader, writer, @@ -159,7 +160,7 @@ async fn scenario_basic() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -188,7 +189,7 @@ async fn scenario_join_and_reboot() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -258,7 +259,7 @@ async fn scenario_force_join_msg() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream1 = TcpStream::connect(server.server.addr).await?; let mut s1 = TestScope::new(&mut stream1); @@ -324,9 +325,9 @@ async fn scenario_two_users() -> Result<()> { // test scenario server.storage.create_user("tester1").await?; - server.storage.set_password("tester1", "password").await?; + Authenticator::new(&server.storage).set_password("tester1", "password").await?; server.storage.create_user("tester2").await?; - server.storage.set_password("tester2", "password").await?; + Authenticator::new(&server.storage).set_password("tester2", "password").await?; let mut stream1 = TcpStream::connect(server.server.addr).await?; let mut s1 = TestScope::new(&mut stream1); @@ -388,7 +389,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -428,7 +429,7 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -467,7 +468,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -505,7 +506,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -549,7 +550,7 @@ async fn terminate_socket_scenario() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -574,7 +575,7 @@ async fn server_time_capability() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index be687a4..bece5d9 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -16,6 +16,7 @@ use tokio_rustls::rustls::client::ServerCertVerifier; use tokio_rustls::rustls::{ClientConfig, ServerName}; use tokio_rustls::TlsConnector; +use lavina_core::auth::Authenticator; use lavina_core::repo::{Storage, StorageConfig}; use lavina_core::LavinaCore; use projection_xmpp::{launch, RunningServer, ServerConfig}; @@ -158,7 +159,7 @@ async fn scenario_basic() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -210,7 +211,7 @@ async fn scenario_basic_without_headers() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); @@ -260,7 +261,7 @@ async fn terminate_socket() -> Result<()> { // test scenario server.storage.create_user("tester").await?; - server.storage.set_password("tester", "password").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); From ec49489ef17861ece8530caa643398765b15f29d Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Tue, 23 Apr 2024 19:14:46 +0200 Subject: [PATCH 26/37] validate that rooms and dialogs are owned exclusively on shutdown --- crates/lavina-core/src/dialog.rs | 15 +++++++++++++ crates/lavina-core/src/lib.rs | 6 +++-- crates/lavina-core/src/player.rs | 10 +++++++++ crates/lavina-core/src/room.rs | 11 +++++++++ crates/projection-irc/tests/lib.rs | 35 +++++++++++++++++++++-------- crates/projection-xmpp/tests/lib.rs | 13 ++++++++--- 6 files changed, 76 insertions(+), 14 deletions(-) diff --git a/crates/lavina-core/src/dialog.rs b/crates/lavina-core/src/dialog.rs index 66fe8b5..f87294c 100644 --- a/crates/lavina-core/src/dialog.rs +++ b/crates/lavina-core/src/dialog.rs @@ -130,6 +130,21 @@ impl DialogRegistry { let mut guard = self.0.write().await; guard.players = Some(players); } + + pub async fn unset_players(&self) { + let mut guard = self.0.write().await; + guard.players = None; + } + + pub fn shutdown(self) -> Result<()> { + let res = match Arc::try_unwrap(self.0) { + Ok(e) => e, + Err(_) => return Err(fail("failed to acquire dialogs ownership on shutdown")), + }; + let res = res.into_inner(); + drop(res); + Ok(()) + } } #[cfg(test)] diff --git a/crates/lavina-core/src/lib.rs b/crates/lavina-core/src/lib.rs index 1128c61..b251ed9 100644 --- a/crates/lavina-core/src/lib.rs +++ b/crates/lavina-core/src/lib.rs @@ -40,8 +40,10 @@ impl LavinaCore { pub async fn shutdown(mut self) -> Result<()> { self.players.shutdown_all().await?; - drop(self.players); - drop(self.rooms); + self.dialogs.unset_players().await; + self.players.shutdown()?; + self.dialogs.shutdown()?; + self.rooms.shutdown()?; Ok(()) } } diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 4d6f6cb..2f5edb4 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -258,6 +258,16 @@ impl PlayerRegistry { Ok(PlayerRegistry(Arc::new(RwLock::new(inner)))) } + pub fn shutdown(self) -> Result<()> { + let res = match Arc::try_unwrap(self.0) { + Ok(e) => e, + Err(_) => return Err(fail("failed to acquire players ownership on shutdown")), + }; + let res = res.into_inner(); + drop(res); + Ok(()) + } + pub async fn get_player(&self, id: &PlayerId) -> Option { let inner = self.0.read().await; inner.players.get(id).map(|(handle, _)| handle.clone()) diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index 52ac7c4..d50e169 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -48,6 +48,17 @@ impl RoomRegistry { Ok(RoomRegistry(Arc::new(AsyncRwLock::new(inner)))) } + pub fn shutdown(self) -> Result<()> { + let res = match Arc::try_unwrap(self.0) { + Ok(e) => e, + Err(_) => return Err(fail("failed to acquire rooms ownership on shutdown")), + }; + let res = res.into_inner(); + // TODO drop all rooms + drop(res); + Ok(()) + } + pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result { let mut inner = self.0.write().await; if let Some(room_handle) = inner.get_or_load_room(&room_id).await? { diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 2de4f9e..2b069ef 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -151,6 +151,13 @@ impl TestServer { server, }) } + + async fn shutdown(self) -> Result<()> { + self.server.terminate().await?; + self.core.shutdown().await?; + self.storage.close().await?; + Ok(()) + } } #[tokio::test] @@ -178,7 +185,7 @@ async fn scenario_basic() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -248,7 +255,7 @@ async fn scenario_join_and_reboot() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -314,7 +321,7 @@ async fn scenario_force_join_msg() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -375,6 +382,11 @@ async fn scenario_two_users() -> Result<()> { s1.expect(":tester1 PART #test").await?; // The second user should receive the PART message s2.expect(":tester1 PART #test").await?; + + stream1.shutdown().await?; + stream2.shutdown().await?; + + server.shutdown().await?; Ok(()) } @@ -418,7 +430,7 @@ async fn scenario_cap_full_negotiation() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -457,7 +469,7 @@ async fn scenario_cap_full_negotiation_nick_last() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -495,7 +507,7 @@ async fn scenario_cap_short_negotiation() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -539,7 +551,7 @@ async fn scenario_cap_sasl_fail() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -562,7 +574,7 @@ async fn terminate_socket_scenario() -> Result<()> { s.send("AUTHENTICATE PLAIN").await?; s.expect(":testserver AUTHENTICATE +").await?; - server.server.terminate().await?; + server.shutdown().await?; assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof); Ok(()) @@ -633,7 +645,7 @@ async fn server_time_capability() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -694,5 +706,10 @@ async fn scenario_two_players_dialog() -> Result<()> { s1.expect(":tester2 PRIVMSG tester1 :good").await?; s1.expect_nothing().await?; + stream1.shutdown().await?; + stream2.shutdown().await?; + + server.shutdown().await?; + Ok(()) } diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index bece5d9..88af83d 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -150,6 +150,13 @@ impl TestServer { server, }) } + + async fn shutdown(self) -> Result<()> { + self.server.terminate().await?; + self.core.shutdown().await?; + self.storage.close().await?; + Ok(()) + } } #[tokio::test] @@ -200,7 +207,7 @@ async fn scenario_basic() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -250,7 +257,7 @@ async fn scenario_basic_without_headers() -> Result<()> { // wrap up - server.server.terminate().await?; + server.shutdown().await?; Ok(()) } @@ -291,7 +298,7 @@ async fn terminate_socket() -> Result<()> { let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; tracing::info!("TLS connection established"); - server.server.terminate().await?; + server.shutdown().await?; assert_eq!(stream.read_u8().await.unwrap_err().kind(), ErrorKind::UnexpectedEof); From 4ff09ea05f90db2bf320cbd5af29d17223da3f8f Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Fri, 26 Apr 2024 10:16:23 +0000 Subject: [PATCH 27/37] tracing otlp exporter and instrumentation annotations (#57) Resolves #56 Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/57 --- Cargo.lock | 435 +++++++++++++++++++++++++++- Cargo.toml | 5 + crates/lavina-core/src/auth.rs | 2 + crates/lavina-core/src/player.rs | 64 +++- crates/lavina-core/src/repo/auth.rs | 1 + crates/lavina-core/src/repo/mod.rs | 6 + crates/lavina-core/src/repo/room.rs | 3 + crates/lavina-core/src/room.rs | 13 + crates/projection-irc/src/lib.rs | 1 + docs/running.md | 14 + src/main.rs | 60 +++- 11 files changed, 569 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9658ff6..ef728a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,6 +132,39 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9" +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + +[[package]] +name = "async-trait" +version = "0.1.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "atoi" version = "2.0.0" @@ -156,6 +189,51 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +[[package]] +name = "axum" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +dependencies = [ + "async-trait", + "axum-core", + "bitflags 1.3.2", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 0.2.12", + "http-body 0.4.6", + "mime", + "rustversion", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.71" @@ -360,6 +438,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crossbeam-channel" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-queue" version = "0.3.11" @@ -632,6 +719,37 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "h2" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap 2.2.6", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.14.3" @@ -648,7 +766,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown", + "hashbrown 0.14.3", ] [[package]] @@ -705,6 +823,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.1.0" @@ -716,6 +845,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.0" @@ -723,7 +863,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http", + "http 1.1.0", ] [[package]] @@ -734,8 +874,8 @@ checksum = "0475f8b2ac86659c21b64320d5d653f9efe42acd2a4e560073ec61a155a34f1d" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.1.0", + "http-body 1.0.0", "pin-project-lite", ] @@ -751,6 +891,30 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "hyper" +version = "0.14.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + [[package]] name = "hyper" version = "1.3.1" @@ -760,8 +924,8 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http", - "http-body", + "http 1.1.0", + "http-body 1.0.0", "httparse", "httpdate", "itoa", @@ -771,6 +935,18 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-timeout" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +dependencies = [ + "hyper 0.14.28", + "pin-project-lite", + "tokio", + "tokio-io-timeout", +] + [[package]] name = "hyper-util" version = "0.1.3" @@ -780,9 +956,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http", - "http-body", - "hyper", + "http 1.1.0", + "http-body 1.0.0", + "hyper 1.3.1", "pin-project-lite", "socket2", "tokio", @@ -824,6 +1000,16 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + [[package]] name = "indexmap" version = "2.2.6" @@ -831,7 +1017,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.3", ] [[package]] @@ -881,11 +1067,15 @@ dependencies = [ "figment", "futures-util", "http-body-util", - "hyper", + "hyper 1.3.1", "hyper-util", "lavina-core", "mgmt-api", "nonempty", + "opentelemetry", + "opentelemetry-otlp", + "opentelemetry-semantic-conventions", + "opentelemetry_sdk", "projection-irc", "projection-xmpp", "prometheus", @@ -895,6 +1085,7 @@ dependencies = [ "serde_json", "tokio", "tracing", + "tracing-opentelemetry", "tracing-subscriber", ] @@ -967,6 +1158,12 @@ version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md-5" version = "0.10.6" @@ -1120,6 +1317,89 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "opentelemetry" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "900d57987be3f2aeb70d385fff9b27fb74c5723cc9a52d904d4f9c807a0667bf" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "once_cell", + "pin-project-lite", + "thiserror", + "urlencoding", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a016b8d9495c639af2145ac22387dcb88e44118e45320d9238fbf4e7889abcb" +dependencies = [ + "async-trait", + "futures-core", + "http 0.2.12", + "opentelemetry", + "opentelemetry-proto", + "opentelemetry-semantic-conventions", + "opentelemetry_sdk", + "prost", + "thiserror", + "tokio", + "tonic", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a8fddc9b68f5b80dae9d6f510b88e02396f006ad48cac349411fbecc80caae4" +dependencies = [ + "opentelemetry", + "opentelemetry_sdk", + "prost", + "tonic", +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9ab5bd6c42fb9349dcf28af2ba9a0667f697f9bdcca045d39f2cec5543e2910" + +[[package]] +name = "opentelemetry_sdk" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e90c7113be649e31e9a0f8b5ee24ed7a16923b322c3c5ab6367469c049d6b7e" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures-channel", + "futures-executor", + "futures-util", + "glob", + "once_cell", + "opentelemetry", + "ordered-float", + "percent-encoding", + "rand", + "thiserror", + "tokio", + "tokio-stream", +] + +[[package]] +name = "ordered-float" +version = "4.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +dependencies = [ + "num-traits", +] + [[package]] name = "overload" version = "0.1.1" @@ -1346,6 +1626,29 @@ dependencies = [ "thiserror", ] +[[package]] +name = "prost" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0f5d036824e4761737860779c906171497f6d55681139d8312388f8fe398922" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-derive" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19de2de2a00075bf566bee3bd4db014b11587e84184d3f7a791bc17f1a8e9e48" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "proto-irc" version = "0.0.2-dev" @@ -1468,10 +1771,10 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "http", - "http-body", + "http 1.1.0", + "http-body 1.0.0", "http-body-util", - "hyper", + "hyper 1.3.1", "hyper-util", "ipnet", "js-sys", @@ -1587,6 +1890,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47" + [[package]] name = "ryu" version = "1.0.17" @@ -1821,7 +2130,7 @@ dependencies = [ "futures-util", "hashlink", "hex", - "indexmap", + "indexmap 2.2.6", "log", "memchr", "once_cell", @@ -2108,6 +2417,16 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "tokio-io-timeout" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" +dependencies = [ + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-macros" version = "2.2.0" @@ -2129,6 +2448,31 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + [[package]] name = "toml" version = "0.8.12" @@ -2156,13 +2500,40 @@ version = "0.22.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb686a972ccef8537b39eead3968b0e8616cb5040dbb9bba93007c8e07c9215f" dependencies = [ - "indexmap", + "indexmap 2.2.6", "serde", "serde_spanned", "toml_datetime", "winnow", ] +[[package]] +name = "tonic" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.21.7", + "bytes", + "h2", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", + "hyper-timeout", + "percent-encoding", + "pin-project", + "prost", + "tokio", + "tokio-stream", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.4.13" @@ -2171,9 +2542,13 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", + "indexmap 1.9.3", "pin-project", "pin-project-lite", + "rand", + "slab", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -2235,6 +2610,24 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-opentelemetry" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9be14ba1bbe4ab79e9229f7f89fab8d120b865859f10527f31c033e599d2284" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry", + "opentelemetry_sdk", + "smallvec", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "web-time", +] + [[package]] name = "tracing-subscriber" version = "0.3.18" @@ -2456,6 +2849,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "whoami" version = "1.5.1" diff --git a/Cargo.toml b/Cargo.toml index 0158e42..fa04096 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,11 @@ projection-irc = { path = "crates/projection-irc" } projection-xmpp = { path = "crates/projection-xmpp" } mgmt-api = { path = "crates/mgmt-api" } clap.workspace = true +opentelemetry = "0.22.0" +opentelemetry-semantic-conventions = "0.14.0" +opentelemetry_sdk = { version = "0.22.1", features = ["rt-tokio"] } +opentelemetry-otlp = "0.15.0" +tracing-opentelemetry = "0.23.0" [dev-dependencies] assert_matches.workspace = true diff --git a/crates/lavina-core/src/auth.rs b/crates/lavina-core/src/auth.rs index ccf962d..6e6f4ba 100644 --- a/crates/lavina-core/src/auth.rs +++ b/crates/lavina-core/src/auth.rs @@ -25,6 +25,7 @@ impl<'a> Authenticator<'a> { Self { storage } } + #[tracing::instrument(skip(self, provided_password), name = "Authenticator::authenticate")] pub async fn authenticate(&self, login: &str, provided_password: &str) -> Result { let Some(stored_user) = self.storage.retrieve_user_by_name(login).await? else { return Ok(Verdict::UserNotFound); @@ -46,6 +47,7 @@ impl<'a> Authenticator<'a> { Ok(Verdict::InvalidPassword) } + #[tracing::instrument(skip(self, provided_password), name = "Authenticator::set_password")] pub async fn set_password(&self, login: &str, provided_password: &str) -> Result { let Some(u) = self.storage.retrieve_user_by_name(login).await? else { return Ok(UpdatePasswordResult::UserNotFound); diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 2f5edb4..6dc65de 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -15,6 +15,7 @@ use prometheus::{IntGauge, Registry as MetricsRegistry}; use serde::Serialize; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::RwLock; +use tracing::{Instrument, Span}; use crate::dialog::DialogRegistry; use crate::prelude::*; @@ -59,6 +60,7 @@ pub struct PlayerConnection { } impl PlayerConnection { /// Handled in [Player::send_message]. + #[tracing::instrument(skip(self, body), name = "PlayerConnection::send_message")] pub async fn send_message(&mut self, room_id: RoomId, body: Str) -> Result { let (promise, deferred) = oneshot(); let cmd = ClientCommand::SendMessage { room_id, body, promise }; @@ -67,6 +69,7 @@ impl PlayerConnection { } /// Handled in [Player::join_room]. + #[tracing::instrument(skip(self), name = "PlayerConnection::join_room")] pub async fn join_room(&mut self, room_id: RoomId) -> Result { let (promise, deferred) = oneshot(); let cmd = ClientCommand::JoinRoom { room_id, promise }; @@ -75,6 +78,7 @@ impl PlayerConnection { } /// Handled in [Player::change_topic]. + #[tracing::instrument(skip(self, new_topic), name = "PlayerConnection::change_topic")] pub async fn change_topic(&mut self, room_id: RoomId, new_topic: Str) -> Result<()> { let (promise, deferred) = oneshot(); let cmd = ClientCommand::ChangeTopic { @@ -87,6 +91,7 @@ impl PlayerConnection { } /// Handled in [Player::leave_room]. + #[tracing::instrument(skip(self), name = "PlayerConnection::leave_room")] pub async fn leave_room(&mut self, room_id: RoomId) -> Result<()> { let (promise, deferred) = oneshot(); let cmd = ClientCommand::LeaveRoom { room_id, promise }; @@ -99,6 +104,7 @@ impl PlayerConnection { } /// Handled in [Player::get_rooms]. + #[tracing::instrument(skip(self), name = "PlayerConnection::get_rooms")] pub async fn get_rooms(&self) -> Result> { let (promise, deferred) = oneshot(); let cmd = ClientCommand::GetRooms { promise }; @@ -107,6 +113,7 @@ impl PlayerConnection { } /// Handler in [Player::send_dialog_message]. + #[tracing::instrument(skip(self, body), name = "PlayerConnection::send_dialog_message")] pub async fn send_dialog_message(&self, recipient: PlayerId, body: Str) -> Result<()> { let (promise, deferred) = oneshot(); let cmd = ClientCommand::SendDialogMessage { @@ -122,14 +129,14 @@ impl PlayerConnection { /// Handle to a player actor. #[derive(Clone)] pub struct PlayerHandle { - tx: Sender, + tx: Sender<(ActorCommand, Span)>, } impl PlayerHandle { pub async fn subscribe(&self) -> PlayerConnection { let (sender, receiver) = channel(32); let (promise, deferred) = oneshot(); let cmd = ActorCommand::AddConnection { sender, promise }; - let _ = self.tx.send(cmd).await; + self.send(cmd).await; let connection_id = deferred.await.unwrap(); PlayerConnection { connection_id, @@ -139,8 +146,9 @@ impl PlayerHandle { } async fn send(&self, command: ActorCommand) { + let span = tracing::span!(tracing::Level::INFO, "PlayerHandle::send"); // TODO either handle the error or doc why it is safe to ignore - let _ = self.tx.send(command).await; + let _ = self.tx.send((command, span)).await; } pub async fn update(&self, update: Updates) { @@ -332,7 +340,7 @@ struct Player { connections: AnonTable>, my_rooms: HashMap, banned_from: HashSet, - rx: Receiver, + rx: Receiver<(ActorCommand, Span)>, handle: PlayerHandle, rooms: RoomRegistry, dialogs: DialogRegistry, @@ -379,20 +387,36 @@ impl Player { } } while let Some(cmd) = self.rx.recv().await { - match cmd { - ActorCommand::AddConnection { sender, promise } => { - let connection_id = self.connections.insert(sender); - if let Err(connection_id) = promise.send(ConnectionId(connection_id)) { - log::warn!("Connection {connection_id:?} terminated before finalization"); - self.terminate_connection(connection_id); + let (cmd, span) = cmd; + let should_stop = async { + match cmd { + ActorCommand::AddConnection { sender, promise } => { + let connection_id = self.connections.insert(sender); + if let Err(connection_id) = promise.send(ConnectionId(connection_id)) { + log::warn!("Connection {connection_id:?} terminated before finalization"); + self.terminate_connection(connection_id); + } + false } + ActorCommand::TerminateConnection(connection_id) => { + self.terminate_connection(connection_id); + false + } + ActorCommand::Update(update) => { + self.handle_update(update).await; + false + } + ActorCommand::ClientCommand(cmd, connection_id) => { + self.handle_cmd(cmd, connection_id).await; + false + } + ActorCommand::Stop => true, } - ActorCommand::TerminateConnection(connection_id) => { - self.terminate_connection(connection_id); - } - ActorCommand::Update(update) => self.handle_update(update).await, - ActorCommand::ClientCommand(cmd, connection_id) => self.handle_cmd(cmd, connection_id).await, - ActorCommand::Stop => break, + } + .instrument(span) + .await; + if should_stop { + break; } } log::debug!("Shutting down player actor #{:?}", self.player_id); @@ -400,6 +424,7 @@ impl Player { } /// Handle an incoming update by changing the internal state and broadcasting it to all connections if necessary. + #[tracing::instrument(skip(self, update), name = "Player::handle_update")] async fn handle_update(&mut self, update: Updates) { log::debug!( "Player received an update, broadcasting to {} connections", @@ -461,6 +486,7 @@ impl Player { } } + #[tracing::instrument(skip(self), name = "Player::join_room")] async fn join_room(&mut self, connection_id: ConnectionId, room_id: RoomId) -> JoinResult { if self.banned_from.contains(&room_id) { return JoinResult::Banned; @@ -488,6 +514,7 @@ impl Player { JoinResult::Success(room_info) } + #[tracing::instrument(skip(self), name = "Player::leave_room")] async fn leave_room(&mut self, connection_id: ConnectionId, room_id: RoomId) { let room = self.my_rooms.remove(&room_id); if let Some(room) = room { @@ -501,6 +528,7 @@ impl Player { self.broadcast_update(update, connection_id).await; } + #[tracing::instrument(skip(self, body), name = "Player::send_message")] async fn send_message(&mut self, connection_id: ConnectionId, room_id: RoomId, body: Str) -> SendMessageResult { let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); @@ -518,6 +546,7 @@ impl Player { SendMessageResult::Success(created_at) } + #[tracing::instrument(skip(self, new_topic), name = "Player::change_topic")] async fn change_topic(&mut self, connection_id: ConnectionId, room_id: RoomId, new_topic: Str) { let Some(room) = self.my_rooms.get(&room_id) else { tracing::info!("no room found"); @@ -528,6 +557,7 @@ impl Player { self.broadcast_update(update, connection_id).await; } + #[tracing::instrument(skip(self), name = "Player::get_rooms")] async fn get_rooms(&self) -> Vec { let mut response = vec![]; for (_, handle) in &self.my_rooms { @@ -536,6 +566,7 @@ impl Player { response } + #[tracing::instrument(skip(self, body), name = "Player::send_dialog_message")] async fn send_dialog_message(&self, connection_id: ConnectionId, recipient: PlayerId, body: Str) { let created_at = chrono::Utc::now(); self.dialogs.send_message(self.player_id.clone(), recipient.clone(), body.clone(), &created_at).await.unwrap(); @@ -552,6 +583,7 @@ impl Player { /// /// This is called after handling a client command. /// Sending the update to the connection which sent the command is handled by the connection itself. + #[tracing::instrument(skip(self, update), name = "Player::broadcast_update")] async fn broadcast_update(&self, update: Updates, except: ConnectionId) { for (a, b) in &self.connections { if ConnectionId(a) == except { diff --git a/crates/lavina-core/src/repo/auth.rs b/crates/lavina-core/src/repo/auth.rs index f7c0d69..ae67df5 100644 --- a/crates/lavina-core/src/repo/auth.rs +++ b/crates/lavina-core/src/repo/auth.rs @@ -3,6 +3,7 @@ use anyhow::Result; use crate::repo::Storage; impl Storage { + #[tracing::instrument(skip(self), name = "Storage::set_argon2_challenge")] pub async fn set_argon2_challenge(&self, user_id: u32, hash: &str) -> Result<()> { let mut executor = self.conn.lock().await; sqlx::query( diff --git a/crates/lavina-core/src/repo/mod.rs b/crates/lavina-core/src/repo/mod.rs index 9c7aff6..f0d210c 100644 --- a/crates/lavina-core/src/repo/mod.rs +++ b/crates/lavina-core/src/repo/mod.rs @@ -40,6 +40,7 @@ impl Storage { Ok(Storage { conn }) } + #[tracing::instrument(skip(self), name = "Storage::retrieve_user_by_name")] pub async fn retrieve_user_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( @@ -55,6 +56,7 @@ impl Storage { Ok(res) } + #[tracing::instrument(skip(self), name = "Storage::retrieve_room_by_name")] pub async fn retrieve_room_by_name(&self, name: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( @@ -69,6 +71,7 @@ impl Storage { Ok(res) } + #[tracing::instrument(skip(self, topic), name = "Storage::create_new_room")] pub async fn create_new_room(&mut self, name: &str, topic: &str) -> Result { let mut executor = self.conn.lock().await; let (id,): (u32,) = sqlx::query_as( @@ -84,6 +87,7 @@ impl Storage { Ok(id) } + #[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_message")] pub async fn insert_message( &mut self, room_id: u32, @@ -127,6 +131,7 @@ impl Storage { Ok(()) } + #[tracing::instrument(skip(self), name = "Storage::create_user")] pub async fn create_user(&mut self, name: &str) -> Result<()> { let query = sqlx::query( "insert into users(name) @@ -139,6 +144,7 @@ impl Storage { Ok(()) } + #[tracing::instrument(skip(self, pwd), name = "Storage::set_password")] pub async fn set_password<'a>(&'a self, name: &'a str, pwd: &'a str) -> Result> { async fn inner(txn: &mut Transaction<'_, Sqlite>, name: &str, pwd: &str) -> Result> { let id: Option<(u32,)> = sqlx::query_as("select * from users where name = ? limit 1;") diff --git a/crates/lavina-core/src/repo/room.rs b/crates/lavina-core/src/repo/room.rs index 96b89f2..38de47d 100644 --- a/crates/lavina-core/src/repo/room.rs +++ b/crates/lavina-core/src/repo/room.rs @@ -3,6 +3,7 @@ use anyhow::Result; use crate::repo::Storage; impl Storage { + #[tracing::instrument(skip(self), name = "Storage::add_room_member")] pub async fn add_room_member(&self, room_id: u32, player_id: u32) -> Result<()> { let mut executor = self.conn.lock().await; sqlx::query( @@ -17,6 +18,7 @@ impl Storage { Ok(()) } + #[tracing::instrument(skip(self), name = "Storage::remove_room_member")] pub async fn remove_room_member(&self, room_id: u32, player_id: u32) -> Result<()> { let mut executor = self.conn.lock().await; sqlx::query( @@ -31,6 +33,7 @@ impl Storage { Ok(()) } + #[tracing::instrument(skip(self, topic), name = "Storage::set_room_topic")] pub async fn set_room_topic(&mut self, id: u32, topic: &str) -> Result<()> { let mut executor = self.conn.lock().await; sqlx::query( diff --git a/crates/lavina-core/src/room.rs b/crates/lavina-core/src/room.rs index d50e169..17a463b 100644 --- a/crates/lavina-core/src/room.rs +++ b/crates/lavina-core/src/room.rs @@ -59,6 +59,7 @@ impl RoomRegistry { Ok(()) } + #[tracing::instrument(skip(self), name = "RoomRegistry::get_or_create_room")] pub async fn get_or_create_room(&mut self, room_id: RoomId) -> Result { let mut inner = self.0.write().await; if let Some(room_handle) = inner.get_or_load_room(&room_id).await? { @@ -83,11 +84,13 @@ impl RoomRegistry { } } + #[tracing::instrument(skip(self), name = "RoomRegistry::get_room")] pub async fn get_room(&self, room_id: &RoomId) -> Option { let mut inner = self.0.write().await; inner.get_or_load_room(room_id).await.unwrap() } + #[tracing::instrument(skip(self), name = "RoomRegistry::get_all_rooms")] pub async fn get_all_rooms(&self) -> Vec { let handles = { let inner = self.0.read().await; @@ -109,6 +112,7 @@ struct RoomRegistryInner { } impl RoomRegistryInner { + #[tracing::instrument(skip(self), name = "RoomRegistryInner::get_or_load_room")] async fn get_or_load_room(&mut self, room_id: &RoomId) -> Result> { if let Some(room_handle) = self.rooms.get(room_id) { log::debug!("Room {} was loaded already", &room_id.0); @@ -138,12 +142,14 @@ impl RoomRegistryInner { #[derive(Clone)] pub struct RoomHandle(Arc>); impl RoomHandle { + #[tracing::instrument(skip(self, player_handle), name = "RoomHandle::subscribe")] pub async fn subscribe(&self, player_id: &PlayerId, player_handle: PlayerHandle) { let mut lock = self.0.write().await; tracing::info!("Adding a subscriber to a room"); lock.subscriptions.insert(player_id.clone(), player_handle); } + #[tracing::instrument(skip(self), name = "RoomHandle::add_member")] pub async fn add_member(&self, player_id: &PlayerId, player_storage_id: u32) { let mut lock = self.0.write().await; tracing::info!("Adding a new member to a room"); @@ -157,11 +163,13 @@ impl RoomHandle { lock.broadcast_update(update, player_id).await; } + #[tracing::instrument(skip(self), name = "RoomHandle::unsubscribe")] pub async fn unsubscribe(&self, player_id: &PlayerId) { let mut lock = self.0.write().await; lock.subscriptions.remove(player_id); } + #[tracing::instrument(skip(self), name = "RoomHandle::remove_member")] pub async fn remove_member(&self, player_id: &PlayerId, player_storage_id: u32) { let mut lock = self.0.write().await; tracing::info!("Removing a member from a room"); @@ -175,6 +183,7 @@ impl RoomHandle { lock.broadcast_update(update, player_id).await; } + #[tracing::instrument(skip(self, body, created_at), name = "RoomHandle::send_message")] pub async fn send_message(&self, player_id: &PlayerId, body: Str, created_at: DateTime) { let mut lock = self.0.write().await; let res = lock.send_message(player_id, body, created_at).await; @@ -183,6 +192,7 @@ impl RoomHandle { } } + #[tracing::instrument(skip(self), name = "RoomHandle::get_room_info")] pub async fn get_room_info(&self) -> RoomInfo { let lock = self.0.read().await; RoomInfo { @@ -192,6 +202,7 @@ impl RoomHandle { } } + #[tracing::instrument(skip(self, new_topic), name = "RoomHandle::set_topic")] pub async fn set_topic(&self, changer_id: &PlayerId, new_topic: Str) { let mut lock = self.0.write().await; let storage_id = lock.storage_id; @@ -220,6 +231,7 @@ struct Room { storage: Storage, } impl Room { + #[tracing::instrument(skip(self, body, created_at), name = "Room::send_message")] async fn send_message(&mut self, author_id: &PlayerId, body: Str, created_at: DateTime) -> Result<()> { tracing::info!("Adding a message to room"); self.storage @@ -246,6 +258,7 @@ impl Room { /// /// This is called after handling a client command. /// Sending the update to the player who sent the command is handled by the player actor. + #[tracing::instrument(skip(self, update), name = "Room::broadcast_update")] async fn broadcast_update(&self, update: Updates, except: &PlayerId) { tracing::debug!("Broadcasting an update to {} subs", self.subscriptions.len()); for (player_id, sub) in &self.subscriptions { diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 342682a..a320c5d 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -678,6 +678,7 @@ enum HandleResult { Leave, } +#[tracing::instrument(skip_all, name = "handle_incoming_message")] async fn handle_incoming_message( buffer: &str, config: &ServerConfig, diff --git a/docs/running.md b/docs/running.md index ad422a1..74b4c60 100644 --- a/docs/running.md +++ b/docs/running.md @@ -23,6 +23,11 @@ hostname = "localhost" [storage] db_path = "db.sqlite" + +[tracing] +# otlp grpc endpoint +endpoint = "http://jaeger:4317" +service_name = "lavina" ``` ## With Docker Compose @@ -41,6 +46,15 @@ services: - '5222:5222' # xmpp - '6667:6667' # irc non-tls - '127.0.0.1:1380:8080' # management http (private) + # if you want to observe traces + jaeger: + image: "jaegertracing/all-in-one:1.56" + ports: + - "16686:16686" # web ui + - "4317:4317" # grpc ingest endpoint + environment: + - COLLECTOR_OTLP_ENABLED=true + - SPAN_STORAGE_TYPE=memory ``` ## With Cargo diff --git a/src/main.rs b/src/main.rs index 98c45f8..d72055b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,8 +6,17 @@ use std::path::Path; use clap::Parser; use figment::providers::Format; use figment::{providers::Toml, Figment}; +use opentelemetry::KeyValue; +use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::trace::{BatchConfig, RandomIdGenerator, Sampler}; +use opentelemetry_sdk::{runtime, Resource}; +use opentelemetry_semantic_conventions::resource::SERVICE_NAME; +use opentelemetry_semantic_conventions::SCHEMA_URL; use prometheus::Registry as MetricsRegistry; use serde::Deserialize; +use tracing_opentelemetry::OpenTelemetryLayer; +use tracing_subscriber::fmt::Subscriber; +use tracing_subscriber::prelude::*; use lavina_core::prelude::*; use lavina_core::repo::Storage; @@ -19,6 +28,13 @@ struct ServerConfig { irc: projection_irc::ServerConfig, xmpp: projection_xmpp::ServerConfig, storage: lavina_core::repo::StorageConfig, + tracing: Option, +} + +#[derive(Deserialize, Debug)] +struct TracingConfig { + endpoint: String, + service_name: String, } #[derive(Parser)] @@ -36,9 +52,9 @@ fn load_config() -> Result { #[tokio::main] async fn main() -> Result<()> { - set_up_logging()?; let sleep = ctrl_c()?; let config = load_config()?; + set_up_logging(&config.tracing)?; tracing::info!("Booting up"); tracing::info!("Loaded config: {config:?}"); @@ -47,6 +63,7 @@ async fn main() -> Result<()> { irc: irc_config, xmpp: xmpp_config, storage: storage_config, + tracing: _, } = config; let metrics = MetricsRegistry::new(); let storage = Storage::open(storage_config).await?; @@ -87,7 +104,44 @@ fn ctrl_c() -> Result> { Ok(recv(chan)) } -fn set_up_logging() -> Result<()> { - tracing_subscriber::fmt::init(); +fn set_up_logging(tracing_config: &Option) -> Result<()> { + let subscriber = tracing_subscriber::registry().with(tracing_subscriber::fmt::layer()); + + let targets = { + use std::{env, str::FromStr}; + use tracing_subscriber::{filter::Targets, layer::SubscriberExt}; + match env::var("RUST_LOG") { + Ok(var) => Targets::from_str(&var) + .map_err(|e| { + eprintln!("Ignoring `RUST_LOG={:?}`: {}", var, e); + }) + .unwrap_or_default(), + Err(env::VarError::NotPresent) => Targets::new().with_default(Subscriber::DEFAULT_MAX_LEVEL), + Err(e) => { + eprintln!("Ignoring `RUST_LOG`: {}", e); + Targets::new().with_default(Subscriber::DEFAULT_MAX_LEVEL) + } + } + }; + if let Some(config) = tracing_config { + let trace_config = opentelemetry_sdk::trace::Config::default() + .with_sampler(Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(1.0)))) + .with_id_generator(RandomIdGenerator::default()) + .with_resource(Resource::from_schema_url( + [KeyValue::new(SERVICE_NAME, config.service_name.to_string())], + SCHEMA_URL, + )); + let trace_exporter = opentelemetry_otlp::new_exporter().tonic().with_endpoint(&config.endpoint); + let tracer = opentelemetry_otlp::new_pipeline() + .tracing() + .with_trace_config(trace_config) + .with_batch_config(BatchConfig::default()) + .with_exporter(trace_exporter) + .install_batch(runtime::Tokio)?; + let subscriber = subscriber.with(OpenTelemetryLayer::new(tracer)); + targets.with_subscriber(subscriber).try_init()?; + } else { + targets.with_subscriber(subscriber).try_init()?; + } Ok(()) } From 72f5010988a758be2710287657000b5adae1ead4 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Fri, 26 Apr 2024 12:28:13 +0200 Subject: [PATCH 28/37] clean up http.rs a little --- src/http.rs | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/http.rs b/src/http.rs index 4bf3ffe..b39a6f2 100644 --- a/src/http.rs +++ b/src/http.rs @@ -111,14 +111,7 @@ async fn endpoint_create_player( ) -> Result>> { let str = request.collect().await?.to_bytes(); let Ok(res) = serde_json::from_slice::(&str[..]) else { - let payload = ErrorResponse { - code: errors::MALFORMED_REQUEST, - message: "The request payload contains incorrect JSON value", - } - .to_body(); - let mut response = Response::new(payload); - *response.status_mut() = StatusCode::BAD_REQUEST; - return Ok(response); + return Ok(malformed_request()); }; storage.create_user(&res.name).await?; log::info!("Player {} created", res.name); @@ -129,18 +122,11 @@ async fn endpoint_create_player( async fn endpoint_set_password( request: Request, - mut storage: Storage, + storage: Storage, ) -> Result>> { let str = request.collect().await?.to_bytes(); let Ok(res) = serde_json::from_slice::(&str[..]) else { - let payload = ErrorResponse { - code: errors::MALFORMED_REQUEST, - message: "The request payload contains incorrect JSON value", - } - .to_body(); - let mut response = Response::new(payload); - *response.status_mut() = StatusCode::BAD_REQUEST; - return Ok(response); + return Ok(malformed_request()); }; let verdict = Authenticator::new(&storage).set_password(&res.player_name, &res.password).await?; match verdict { @@ -173,19 +159,28 @@ pub fn not_found() -> Response> { response } +fn malformed_request() -> Response> { + let payload = ErrorResponse { + code: errors::MALFORMED_REQUEST, + message: "The request payload contains incorrect JSON value", + } + .to_body(); + + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::BAD_REQUEST; + return response; +} + trait Or5xx { fn or5xx(self) -> Response>; } impl Or5xx for Result>> { fn or5xx(self) -> Response> { - match self { - Ok(e) => e, - Err(e) => { - let mut response = Response::new(Full::new(e.to_string().into())); - *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - response - } - } + self.unwrap_or_else(|e| { + let mut response = Response::new(Full::new(e.to_string().into())); + *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + response + }) } } From 843d0e9c828b03a5e497e4adc42d175ff5fd28bc Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Fri, 26 Apr 2024 13:30:57 +0200 Subject: [PATCH 29/37] bump version --- Cargo.lock | 16 ++++++++-------- Cargo.toml | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ef728a6..83e0d93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1058,7 +1058,7 @@ dependencies = [ [[package]] name = "lavina" -version = "0.0.2-dev" +version = "0.0.2" dependencies = [ "anyhow", "assert_matches", @@ -1091,7 +1091,7 @@ dependencies = [ [[package]] name = "lavina-core" -version = "0.0.2-dev" +version = "0.0.2" dependencies = [ "anyhow", "argon2", @@ -1182,7 +1182,7 @@ checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "mgmt-api" -version = "0.0.2-dev" +version = "0.0.2" dependencies = [ "serde", ] @@ -1573,7 +1573,7 @@ dependencies = [ [[package]] name = "projection-irc" -version = "0.0.2-dev" +version = "0.0.2" dependencies = [ "anyhow", "bitflags 2.5.0", @@ -1592,7 +1592,7 @@ dependencies = [ [[package]] name = "projection-xmpp" -version = "0.0.2-dev" +version = "0.0.2" dependencies = [ "anyhow", "assert_matches", @@ -1651,7 +1651,7 @@ dependencies = [ [[package]] name = "proto-irc" -version = "0.0.2-dev" +version = "0.0.2" dependencies = [ "anyhow", "assert_matches", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "proto-xmpp" -version = "0.0.2-dev" +version = "0.0.2" dependencies = [ "anyhow", "assert_matches", @@ -1904,7 +1904,7 @@ checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "sasl" -version = "0.0.2-dev" +version = "0.0.2" dependencies = [ "anyhow", "base64 0.22.0", diff --git a/Cargo.toml b/Cargo.toml index fa04096..e19222b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ members = [ ] [workspace.package] -version = "0.0.2-dev" +version = "0.0.2" [workspace.dependencies] nom = "7.1.3" From 4b5ab02322e212915995ecc4da0ca7b47bf5e64c Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Fri, 26 Apr 2024 13:43:43 +0200 Subject: [PATCH 30/37] start next version --- Cargo.lock | 16 ++++++++-------- Cargo.toml | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 83e0d93..e2d52e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1058,7 +1058,7 @@ dependencies = [ [[package]] name = "lavina" -version = "0.0.2" +version = "0.0.3-dev" dependencies = [ "anyhow", "assert_matches", @@ -1091,7 +1091,7 @@ dependencies = [ [[package]] name = "lavina-core" -version = "0.0.2" +version = "0.0.3-dev" dependencies = [ "anyhow", "argon2", @@ -1182,7 +1182,7 @@ checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "mgmt-api" -version = "0.0.2" +version = "0.0.3-dev" dependencies = [ "serde", ] @@ -1573,7 +1573,7 @@ dependencies = [ [[package]] name = "projection-irc" -version = "0.0.2" +version = "0.0.3-dev" dependencies = [ "anyhow", "bitflags 2.5.0", @@ -1592,7 +1592,7 @@ dependencies = [ [[package]] name = "projection-xmpp" -version = "0.0.2" +version = "0.0.3-dev" dependencies = [ "anyhow", "assert_matches", @@ -1651,7 +1651,7 @@ dependencies = [ [[package]] name = "proto-irc" -version = "0.0.2" +version = "0.0.3-dev" dependencies = [ "anyhow", "assert_matches", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "proto-xmpp" -version = "0.0.2" +version = "0.0.3-dev" dependencies = [ "anyhow", "assert_matches", @@ -1904,7 +1904,7 @@ checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "sasl" -version = "0.0.2" +version = "0.0.3-dev" dependencies = [ "anyhow", "base64 0.22.0", diff --git a/Cargo.toml b/Cargo.toml index e19222b..093e62c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ members = [ ] [workspace.package] -version = "0.0.2" +version = "0.0.3-dev" [workspace.dependencies] nom = "7.1.3" From ea81ddadfc72ecbc01ed25c24881c64037933e81 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sat, 27 Apr 2024 12:58:27 +0200 Subject: [PATCH 31/37] dialog message persistence --- crates/lavina-core/src/dialog.rs | 21 ++++++++++++++--- crates/lavina-core/src/repo/dialog.rs | 33 +++++++++++++++++++++++---- crates/projection-xmpp/src/lib.rs | 1 + 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/crates/lavina-core/src/dialog.rs b/crates/lavina-core/src/dialog.rs index f87294c..f06d5e8 100644 --- a/crates/lavina-core/src/dialog.rs +++ b/crates/lavina-core/src/dialog.rs @@ -65,7 +65,10 @@ impl DialogRegistry { let dialog = guard.dialogs.get(&id); if let Some(d) = dialog { let mut d = d.write().await; - guard.storage.increment_dialog_message_count(d.storage_id).await?; + guard + .storage + .insert_dialog_message(d.storage_id, d.message_count, from.as_inner(), &body, created_at) + .await?; d.message_count += 1; } else { drop(guard); @@ -73,7 +76,10 @@ impl DialogRegistry { // double check in case concurrent access has loaded this dialog if let Some(d) = guard2.dialogs.get(&id) { let mut d = d.write().await; - guard2.storage.increment_dialog_message_count(d.storage_id).await?; + guard2 + .storage + .insert_dialog_message(d.storage_id, d.message_count, from.as_inner(), &body, created_at) + .await?; d.message_count += 1; } else { let (p1, p2) = id.as_inner(); @@ -86,7 +92,16 @@ impl DialogRegistry { } }; tracing::info!("Dialog {id:?} loaded"); - guard2.storage.increment_dialog_message_count(stored_dialog.id).await?; + guard2 + .storage + .insert_dialog_message( + stored_dialog.id, + stored_dialog.message_count, + from.as_inner(), + &body, + created_at, + ) + .await?; let dialog = Dialog { storage_id: stored_dialog.id, player_storage_id_1: stored_dialog.participant_1, diff --git a/crates/lavina-core/src/repo/dialog.rs b/crates/lavina-core/src/repo/dialog.rs index cbe3161..e228303 100644 --- a/crates/lavina-core/src/repo/dialog.rs +++ b/crates/lavina-core/src/repo/dialog.rs @@ -1,10 +1,11 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use chrono::{DateTime, Utc}; use sqlx::FromRow; use crate::repo::Storage; impl Storage { + #[tracing::instrument(skip(self), name = "Storage::retrieve_dialog")] pub async fn retrieve_dialog(&self, participant_1: &str, participant_2: &str) -> Result> { let mut executor = self.conn.lock().await; let res = sqlx::query_as( @@ -20,19 +21,41 @@ impl Storage { Ok(res) } - pub async fn increment_dialog_message_count(&self, storage_id: u32) -> Result<()> { + #[tracing::instrument(skip(self, content, created_at), name = "Storage::insert_dialog_message")] + pub async fn insert_dialog_message( + &self, + dialog_id: u32, + id: u32, + author_id: &str, + content: &str, + created_at: &DateTime, + ) -> Result<()> { let mut executor = self.conn.lock().await; + let res: Option<(u32,)> = sqlx::query_as("select id from users where name = ?;") + .bind(author_id) + .fetch_optional(&mut *executor) + .await?; + let Some((author_id,)) = res else { + return Err(anyhow!("No such user")); + }; sqlx::query( - "update rooms set message_count = message_count + 1 - where id = ?;", + "insert into dialog_messages(dialog_id, id, author_id, content, created_at) + values (?, ?, ?, ?, ?); + update dialogs set message_count = message_count + 1 where id = ?;", ) - .bind(storage_id) + .bind(dialog_id) + .bind(id) + .bind(author_id) + .bind(content) + .bind(created_at) + .bind(dialog_id) .execute(&mut *executor) .await?; Ok(()) } + #[tracing::instrument(skip(self, created_at), name = "Storage::initialize_dialog")] pub async fn initialize_dialog( &self, participant_1: &str, diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 01f0171..d79bd50 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -428,6 +428,7 @@ struct XmppConnection<'a> { } impl<'a> XmppConnection<'a> { + #[tracing::instrument(skip(self, output, packet), name = "XmppConnection::handle_packet")] async fn handle_packet(&mut self, output: &mut Vec>, packet: ClientPacket) -> Result { let res = match packet { ClientPacket::Iq(iq) => { From a047d55ab5d2093052413ded5ec5db393d155851 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sun, 28 Apr 2024 15:43:22 +0200 Subject: [PATCH 32/37] xmpp: handle correctly unavailable self-presence and improve basic test scenario --- crates/projection-xmpp/src/lib.rs | 1 + crates/projection-xmpp/src/presence.rs | 40 +++++++++++++++----------- crates/projection-xmpp/src/proto.rs | 2 ++ crates/projection-xmpp/tests/lib.rs | 31 +++++++++++++++++++- 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index d79bd50..5468c19 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -447,6 +447,7 @@ impl<'a> XmppConnection<'a> { ServerStreamEnd.serialize(output); true } + ClientPacket::Eos => true, }; Ok(res) } diff --git a/crates/projection-xmpp/src/presence.rs b/crates/projection-xmpp/src/presence.rs index 82ccb61..c9fc938 100644 --- a/crates/projection-xmpp/src/presence.rs +++ b/crates/projection-xmpp/src/presence.rs @@ -14,7 +14,7 @@ impl<'a> XmppConnection<'a> { pub async fn handle_presence(&mut self, output: &mut Vec>, p: Presence) -> Result<()> { match p.to { None => { - self.self_presence(output).await; + self.self_presence(output, p.r#type.as_deref()).await; } Some(Jid { name: Some(name), @@ -33,21 +33,29 @@ impl<'a> XmppConnection<'a> { Ok(()) } - async fn self_presence(&mut self, output: &mut Vec>) { - let response = Presence::<()> { - to: Some(Jid { - name: Some(self.user.xmpp_name.clone()), - server: Server(self.hostname.clone()), - resource: Some(self.user.xmpp_resource.clone()), - }), - from: Some(Jid { - name: Some(self.user.xmpp_name.clone()), - server: Server(self.hostname.clone()), - resource: Some(self.user.xmpp_resource.clone()), - }), - ..Default::default() - }; - response.serialize(output); + async fn self_presence(&mut self, output: &mut Vec>, r#type: Option<&str>) { + match r#type { + Some("unavailable") => { + // do not print anything + } + None => { + let response = Presence::<()> { + to: Some(Jid { + name: Some(self.user.xmpp_name.clone()), + server: Server(self.hostname.clone()), + resource: Some(self.user.xmpp_resource.clone()), + }), + from: Some(Jid { + name: Some(self.user.xmpp_name.clone()), + server: Server(self.hostname.clone()), + resource: Some(self.user.xmpp_resource.clone()), + }), + ..Default::default() + }; + response.serialize(output); + } + _ => todo!(), + } } async fn muc_presence(&mut self, name: Name, output: &mut Vec>) -> Result<()> { diff --git a/crates/projection-xmpp/src/proto.rs b/crates/projection-xmpp/src/proto.rs index e486b65..8b16ef0 100644 --- a/crates/projection-xmpp/src/proto.rs +++ b/crates/projection-xmpp/src/proto.rs @@ -52,6 +52,7 @@ pub enum ClientPacket { Message(Message), Presence(Presence), StreamEnd, + Eos, } impl FromXml for ClientPacket { @@ -83,6 +84,7 @@ impl FromXml for ClientPacket { return Err(anyhow!("Unexpected XML event: {event:?}")); } } + Event::Eof => Ok(ClientPacket::Eos), _ => { return Err(anyhow!("Unexpected XML event: {event:?}")); } diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 88af83d..ef8c046 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -82,7 +82,7 @@ impl<'a> TestScopeTls<'a> { fn new(stream: &'a mut TlsStream, buffer: Vec) -> TestScopeTls<'a> { let (reader, writer) = tokio::io::split(stream); let reader = NsReader::from_reader(BufReader::new(reader)); - let timeout = Duration::from_millis(100); + let timeout = Duration::from_millis(500); TestScopeTls { reader, @@ -203,6 +203,35 @@ async fn scenario_basic() -> Result<()> { assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); + assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + + // base64-encoded b"\x00tester\x00password" + s.send(r#"AHRlc3RlcgBwYXNzd29yZA=="#) + .await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"success")); + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"bind")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + s.send(r#"kek"#).await?; + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"iq")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"bind")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"jid")); + assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"tester@localhost/tester")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"jid")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"bind")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"iq")); + s.send(r#"Logged out"#).await?; + stream.shutdown().await?; // wrap up From 8ec9ecfe2cceddd3a05fd1ab36318eb6ce1c8c61 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sun, 28 Apr 2024 17:11:29 +0200 Subject: [PATCH 33/37] xmpp: handle incorrect credentials by replying with an error --- crates/projection-xmpp/src/lib.rs | 13 +++--- crates/projection-xmpp/tests/lib.rs | 67 +++++++++++++++++++++++++++++ crates/proto-xmpp/src/sasl.rs | 15 ++++++- 3 files changed, 87 insertions(+), 8 deletions(-) diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 5468c19..fe56481 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -296,20 +296,19 @@ async fn socket_auth( xml_writer.get_mut().flush().await?; let auth: proto_xmpp::sasl::Auth = proto_xmpp::sasl::Auth::parse(xml_reader, reader_buf).await?; - proto_xmpp::sasl::Success.write_xml(xml_writer).await?; - xml_writer.get_mut().flush().await?; match AuthBody::from_str(&auth.body) { Ok(logopass) => { let name = &logopass.login; let verdict = Authenticator::new(storage).authenticate(name, &logopass.password).await?; - // TODO return proper XML errors to the client match verdict { - Verdict::Authenticated => {} - Verdict::UserNotFound => { - return Err(anyhow!("no user found")); + Verdict::Authenticated => { + proto_xmpp::sasl::Success.write_xml(xml_writer).await?; + xml_writer.get_mut().flush().await?; } - Verdict::InvalidPassword => { + Verdict::UserNotFound | Verdict::InvalidPassword => { + proto_xmpp::sasl::Failure.write_xml(xml_writer).await?; + xml_writer.get_mut().flush().await?; return Err(anyhow!("incorrect credentials")); } } diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index ef8c046..8c05128 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -240,6 +240,73 @@ async fn scenario_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scenario_wrong_password() -> Result<()> { + let mut server = TestServer::start().await?; + + // test scenario + + server.storage.create_user("tester").await?; + Authenticator::new(&server.storage).set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.server.addr).await?; + let mut s = TestScope::new(&mut stream); + tracing::info!("TCP connection established"); + + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + let buffer = s.buffer; + tracing::info!("TLS feature negotiation complete"); + + let connector = TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification)) + .with_no_client_auth(), + )); + tracing::info!("Initiating TLS connection..."); + let mut stream = connector.connect(ServerName::IpAddress(server.server.addr.ip()), stream).await?; + tracing::info!("TLS connection established"); + + let mut s = TestScopeTls::new(&mut stream, buffer); + + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); + assert_matches!(s.next_xml_event().await?, Event::Text(b) => assert_eq!(&*b, b"PLAIN")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanism")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"mechanisms")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + + // base64-encoded b"\x00tester\x00password2" + s.send(r#"AHRlc3RlcgBwYXNzd29yZDI="#) + .await?; + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"failure")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"not-authorized")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"failure")); + + stream.shutdown().await?; + + // wrap up + + server.shutdown().await?; + Ok(()) +} + #[tokio::test] async fn scenario_basic_without_headers() -> Result<()> { let mut server = TestServer::start().await?; diff --git a/crates/proto-xmpp/src/sasl.rs b/crates/proto-xmpp/src/sasl.rs index e147962..b042f09 100644 --- a/crates/proto-xmpp/src/sasl.rs +++ b/crates/proto-xmpp/src/sasl.rs @@ -1,7 +1,7 @@ use std::borrow::Borrow; use anyhow::{anyhow, Result}; -use quick_xml::events::{BytesStart, Event}; +use quick_xml::events::{BytesEnd, BytesStart, Event}; use quick_xml::{NsReader, Writer}; use tokio::io::{AsyncBufRead, AsyncWrite}; @@ -74,3 +74,16 @@ impl Success { Ok(()) } } + +pub struct Failure; +impl Failure { + pub async fn write_xml(&self, writer: &mut Writer) -> Result<()> { + let event = BytesStart::new(r#"failure xmlns="urn:ietf:params:xml:ns:xmpp-sasl""#); + writer.write_event_async(Event::Start(event)).await?; + let event = BytesStart::new(r#"not-authorized"#); + writer.write_event_async(Event::Empty(event)).await?; + let event = BytesEnd::new(r#"failure"#); + writer.write_event_async(Event::End(event)).await?; + Ok(()) + } +} From c69513f38b59b9df519b0c56a5e46a8108937644 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sun, 28 Apr 2024 17:19:31 +0200 Subject: [PATCH 34/37] xmpp: use mutable namespace and event in parser coroutines --- crates/projection-xmpp/src/proto.rs | 4 ++-- crates/proto-xmpp/src/bind.rs | 8 ++++---- crates/proto-xmpp/src/client.rs | 18 ++++++++--------- crates/proto-xmpp/src/disco.rs | 30 ++++++++++++++--------------- crates/proto-xmpp/src/muc/mod.rs | 14 +++++++------- crates/proto-xmpp/src/xml/mod.rs | 4 ++-- src/main.rs | 2 +- 7 files changed, 40 insertions(+), 40 deletions(-) diff --git a/crates/projection-xmpp/src/proto.rs b/crates/projection-xmpp/src/proto.rs index 8b16ef0..0b157a6 100644 --- a/crates/projection-xmpp/src/proto.rs +++ b/crates/projection-xmpp/src/proto.rs @@ -25,7 +25,7 @@ impl FromXml for IqClientBody { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let bytes = match event { Event::Start(bytes) => bytes, Event::Empty(bytes) => bytes, @@ -59,7 +59,7 @@ impl FromXml for ClientPacket { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { match event { Event::Start(bytes) | Event::Empty(bytes) => { let name = bytes.name(); diff --git a/crates/proto-xmpp/src/bind.rs b/crates/proto-xmpp/src/bind.rs index dc0d1ce..d27a00e 100644 --- a/crates/proto-xmpp/src/bind.rs +++ b/crates/proto-xmpp/src/bind.rs @@ -82,7 +82,7 @@ impl FromXml for BindRequest { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let mut resource: Option = None; let Event::Start(bytes) = event else { return Err(anyhow!("Unexpected XML event: {event:?}")); @@ -97,15 +97,15 @@ impl FromXml for BindRequest { return Err(anyhow!("Incorrect namespace")); } loop { - let (namespace, event) = yield; + (namespace, event) = yield; match event { Event::Start(bytes) if bytes.name().0 == b"resource" => { - let (namespace, event) = yield; + (namespace, event) = yield; let Event::Text(text) = event else { return Err(anyhow!("Unexpected XML event: {event:?}")); }; resource = Some(std::str::from_utf8(&*text)?.into()); - let (namespace, event) = yield; + (namespace, event) = yield; let Event::End(bytes) = event else { return Err(anyhow!("Unexpected XML event: {event:?}")); }; diff --git a/crates/proto-xmpp/src/client.rs b/crates/proto-xmpp/src/client.rs index 85b3979..05807bd 100644 --- a/crates/proto-xmpp/src/client.rs +++ b/crates/proto-xmpp/src/client.rs @@ -378,7 +378,7 @@ impl Parser for IqParser { } }, IqParserInner::Final(state) => { - if let Event::End(ref bytes) = event { + if let Event::End(_) = event { let id = fail_fast!(state.id.ok_or_else(|| ffail!("No id provided"))); let r#type = fail_fast!(state.r#type.ok_or_else(|| ffail!("No type provided"))); let body = fail_fast!(state.body.ok_or_else(|| ffail!("No body provided"))); @@ -528,7 +528,7 @@ impl FromXml for Presence { type P = impl Parser>>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let (bytes, end) = match event { Event::Start(bytes) => (bytes, false), Event::Empty(bytes) => (bytes, true), @@ -557,37 +557,37 @@ impl FromXml for Presence { return Ok(p); } loop { - let (namespace, event) = yield; + (namespace, event) = yield; match event { Event::Start(bytes) => match bytes.name().0 { b"show" => { - let (_, event) = yield; + (namespace, event) = yield; let Event::Text(bytes) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; let i = PresenceShow::from_str(bytes)?; p.show = Some(i); - let (_, event) = yield; + (namespace, event) = yield; let Event::End(_) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; } b"status" => { - let (_, event) = yield; + (namespace, event) = yield; let Event::Text(bytes) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; let s = std::str::from_utf8(bytes)?; p.status.push(s.to_string()); - let (_, event) = yield; + (namespace, event) = yield; let Event::End(_) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; } b"priority" => { - let (_, event) = yield; + (namespace, event) = yield; let Event::Text(bytes) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; @@ -595,7 +595,7 @@ impl FromXml for Presence { let i = s.parse()?; p.priority = Some(PresencePriority(i)); - let (_, event) = yield; + (namespace, event) = yield; let Event::End(_) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; diff --git a/crates/proto-xmpp/src/disco.rs b/crates/proto-xmpp/src/disco.rs index af7771b..bea4611 100644 --- a/crates/proto-xmpp/src/disco.rs +++ b/crates/proto-xmpp/src/disco.rs @@ -21,7 +21,7 @@ impl FromXml for InfoQuery { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let mut node = None; let mut identity = vec![]; let mut feature = vec![]; @@ -48,7 +48,7 @@ impl FromXml for InfoQuery { }); } loop { - let (namespace, event) = yield; + (namespace, event) = yield; let bytes = match event { Event::Start(bytes) => bytes, Event::Empty(bytes) => bytes, @@ -141,7 +141,7 @@ impl FromXml for Identity { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let mut category = None; let mut name = None; let mut r#type = None; @@ -179,8 +179,8 @@ impl FromXml for Identity { return Ok(item); } - let (namespace, event) = yield; - let Event::End(bytes) = event else { + (namespace, event) = yield; + let Event::End(_) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; Ok(item) @@ -209,7 +209,7 @@ impl FromXml for Feature { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let mut var = None; let (bytes, end) = match event { Event::Start(bytes) => (bytes, false), @@ -234,8 +234,8 @@ impl FromXml for Feature { return Ok(item); } - let (namespace, event) = yield; - let Event::End(bytes) = event else { + (namespace, event) = yield; + let Event::End(_) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; Ok(item) @@ -258,9 +258,9 @@ impl FromXml for ItemQuery { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let mut item = vec![]; - let (bytes, end) = match event { + let (_, end) = match event { Event::Start(bytes) => (bytes, false), Event::Empty(bytes) => (bytes, true), _ => return Err(ffail!("Unexpected XML event: {event:?}")), @@ -269,7 +269,7 @@ impl FromXml for ItemQuery { return Ok(ItemQuery { item }); } loop { - let (namespace, event) = yield; + (namespace, event) = yield; let bytes = match event { Event::Start(bytes) => bytes, Event::Empty(bytes) => bytes, @@ -296,7 +296,7 @@ impl FromXmlTag for ItemQuery { impl ToXml for ItemQuery { fn serialize(&self, events: &mut Vec>) { - let mut bytes = BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS_ITEM)); + let bytes = BytesStart::new(format!(r#"query xmlns="{}""#, XMLNS_ITEM)); let empty = self.item.is_empty(); if empty { events.push(Event::Empty(bytes)); @@ -342,7 +342,7 @@ impl FromXml for Item { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(_, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let mut jid = None; let mut name = None; let mut node = None; @@ -378,8 +378,8 @@ impl FromXml for Item { return Ok(item); } - let (namespace, event) = yield; - let Event::End(bytes) = event else { + (_, event) = yield; + let Event::End(_) = event else { return Err(ffail!("Unexpected XML event: {event:?}")); }; Ok(item) diff --git a/crates/proto-xmpp/src/muc/mod.rs b/crates/proto-xmpp/src/muc/mod.rs index 0a6e702..f357dd0 100644 --- a/crates/proto-xmpp/src/muc/mod.rs +++ b/crates/proto-xmpp/src/muc/mod.rs @@ -19,7 +19,7 @@ impl FromXml for History { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let mut history = History::default(); let (bytes, end) = match event { Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => (bytes, false), @@ -51,7 +51,7 @@ impl FromXml for History { return Ok(history); } - let (namespace, event) = yield; + (namespace, event) = yield; let Event::End(bytes) = event else { return Err(anyhow!("Unexpected XML event: {event:?}")); }; @@ -73,17 +73,17 @@ impl FromXml for Password { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let bytes = match event { Event::Start(bytes) if bytes.name().0 == Self::NAME.as_bytes() => bytes, _ => return Err(anyhow!("Unexpected XML event: {event:?}")), }; - let (namespace, event) = yield; + (namespace, event) = yield; let Event::Text(bytes) = event else { return Err(anyhow!("Unexpected XML event: {event:?}")); }; let s = std::str::from_utf8(bytes)?.to_string(); - let (namespace, event) = yield; + (namespace, event) = yield; let Event::End(bytes) = event else { return Err(anyhow!("Unexpected XML event: {event:?}")); }; @@ -108,7 +108,7 @@ impl FromXml for X { type P = impl Parser>; fn parse() -> Self::P { - |(namespace, event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { + |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { let mut res = X::default(); let (_, end) = match event { Event::Start(bytes) => (bytes, false), @@ -120,7 +120,7 @@ impl FromXml for X { } loop { - let (namespace, event) = yield; + (namespace, event) = yield; let bytes = match event { Event::Start(bytes) => bytes, Event::Empty(bytes) => bytes, diff --git a/crates/proto-xmpp/src/xml/mod.rs b/crates/proto-xmpp/src/xml/mod.rs index 1919ff2..b928fa0 100644 --- a/crates/proto-xmpp/src/xml/mod.rs +++ b/crates/proto-xmpp/src/xml/mod.rs @@ -89,8 +89,8 @@ macro_rules! delegate_parsing { Continuation::Final(Ok(res)) => break Ok(res.into()), Continuation::Final(Err(err)) => break Err(err), Continuation::Continue(p) => { - let (namespace, event) = yield; - parser = p.consume(namespace, event); + ($namespace, $event) = yield; + parser = p.consume($namespace, $event); } } } diff --git a/src/main.rs b/src/main.rs index d72055b..9b73b1f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -109,7 +109,7 @@ fn set_up_logging(tracing_config: &Option) -> Result<()> { let targets = { use std::{env, str::FromStr}; - use tracing_subscriber::{filter::Targets, layer::SubscriberExt}; + use tracing_subscriber::filter::Targets; match env::var("RUST_LOG") { Ok(var) => Targets::from_str(&var) .map_err(|e| { From c1dc2df150468e498df7e433c2d85db1111697da Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Sun, 28 Apr 2024 17:29:31 +0200 Subject: [PATCH 35/37] xmpp: document xml parsing types --- crates/proto-xmpp/src/xml/mod.rs | 41 ++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/crates/proto-xmpp/src/xml/mod.rs b/crates/proto-xmpp/src/xml/mod.rs index b928fa0..79c85ea 100644 --- a/crates/proto-xmpp/src/xml/mod.rs +++ b/crates/proto-xmpp/src/xml/mod.rs @@ -10,9 +10,38 @@ use anyhow::Result; mod ignore; pub use ignore::Ignore; +/// Types which can be parsed from an XML input stream. +/// +/// Example: +/// ``` +/// #![feature(type_alias_impl_trait)] +/// #![feature(impl_trait_in_assoc_type)] +/// #![feature(coroutines)] +/// # use proto_xmpp::xml::FromXml; +/// # use quick_xml::events::Event; +/// # use quick_xml::name::ResolveResult; +/// # use proto_xmpp::xml::Parser; +/// # use anyhow::Result; +/// +/// struct MyStruct; +/// impl FromXml for MyStruct { +/// type P = impl Parser>; +/// +/// fn parse() -> Self::P { +/// |(mut namespace, mut event): (ResolveResult<'static>, &'static Event<'static>)| -> Result { +/// (namespace, event) = yield; +/// Ok(MyStruct) +/// } +/// } +/// } +/// ``` pub trait FromXml: Sized { + /// The type of parser instances. + /// + /// If the result type of the [parse] is anonymous, this type member can be defined by using `impl Trait`. type P: Parser>; + /// Creates a new instance of a parser with an initial state. fn parse() -> Self::P; } @@ -25,9 +54,18 @@ pub trait FromXmlTag: FromXml { const NS: &'static str; } +/// A stateful parser instance which consumes XML events until the parsing is complete. +/// +/// Usually implemented with the experimental coroutine syntax, which yields to consume the next XML event, +/// and returns the final result when the parsing is done. pub trait Parser: Sized { type Output; + /// Advance the parsing by one XML event. + /// + /// This method consumes `self`, but if the parsing is incomplete, + /// it will return the next state of the parser in the returned result. + /// Otherwise, it will return the final result of parsing. fn consume<'a>(self: Self, namespace: ResolveResult, event: &Event<'a>) -> Continuation; } @@ -50,8 +88,11 @@ where } } +/// The result of a single parser iteration. pub enum Continuation { + /// The parsing is complete and the final result is available. Final(Res), + /// The parsing is not complete and more XML events are required. Continue(Parser), } From 31f9da9b05d53e15d6b6e317489e1aca90310b0c Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Mon, 29 Apr 2024 19:13:32 +0200 Subject: [PATCH 36/37] xmpp: fix incorrect auth test --- crates/projection-xmpp/tests/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 8c05128..dd537a1 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -299,7 +299,7 @@ async fn scenario_wrong_password() -> Result<()> { assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"not-authorized")); assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"failure")); - stream.shutdown().await?; + let _ = stream.shutdown().await; // wrap up From 25605322a0009b8f5fc86100c449ef3557d1f196 Mon Sep 17 00:00:00 2001 From: Nikita Vilunov Date: Mon, 29 Apr 2024 17:24:43 +0000 Subject: [PATCH 37/37] player shutdown API (#58) Reviewed-on: https://git.vilunov.me/lavina/lavina/pulls/58 --- crates/lavina-core/src/player.rs | 38 ++++++++++++++++++--- crates/mgmt-api/src/lib.rs | 6 ++++ crates/projection-irc/src/lib.rs | 17 +++++++--- crates/projection-xmpp/src/lib.rs | 46 +++++++++++++++++++------ crates/proto-xmpp/src/lib.rs | 1 + crates/proto-xmpp/src/streamerror.rs | 41 ++++++++++++++++++++++ src/http.rs | 51 ++++++++++++++++++++++------ 7 files changed, 170 insertions(+), 30 deletions(-) create mode 100644 crates/proto-xmpp/src/streamerror.rs diff --git a/crates/lavina-core/src/player.rs b/crates/lavina-core/src/player.rs index 6dc65de..30635b0 100644 --- a/crates/lavina-core/src/player.rs +++ b/crates/lavina-core/src/player.rs @@ -55,7 +55,7 @@ pub struct ConnectionId(pub AnonKey); /// The connection is used to send commands to the player actor and to receive updates that might be sent to the client. pub struct PlayerConnection { pub connection_id: ConnectionId, - pub receiver: Receiver, + pub receiver: Receiver, player_handle: PlayerHandle, } impl PlayerConnection { @@ -160,7 +160,7 @@ impl PlayerHandle { enum ActorCommand { /// Establish a new connection. AddConnection { - sender: Sender, + sender: Sender, promise: Promise, }, /// Terminate an existing connection. @@ -276,11 +276,27 @@ impl PlayerRegistry { Ok(()) } + #[tracing::instrument(skip(self), name = "PlayerRegistry::get_player")] pub async fn get_player(&self, id: &PlayerId) -> Option { let inner = self.0.read().await; inner.players.get(id).map(|(handle, _)| handle.clone()) } + #[tracing::instrument(skip(self), name = "PlayerRegistry::stop_player")] + pub async fn stop_player(&self, id: &PlayerId) -> Result> { + let mut inner = self.0.write().await; + if let Some((handle, fiber)) = inner.players.remove(id) { + handle.send(ActorCommand::Stop).await; + drop(handle); + fiber.await?; + inner.metric_active_players.dec(); + Ok(Some(())) + } else { + Ok(None) + } + } + + #[tracing::instrument(skip(self), name = "PlayerRegistry::get_or_launch_player")] pub async fn get_or_launch_player(&mut self, id: &PlayerId) -> PlayerHandle { let inner = self.0.read().await; if let Some((handle, _)) = inner.players.get(id) { @@ -305,6 +321,7 @@ impl PlayerRegistry { } } + #[tracing::instrument(skip(self), name = "PlayerRegistry::connect_to_player")] pub async fn connect_to_player(&mut self, id: &PlayerId) -> PlayerConnection { let player_handle = self.get_or_launch_player(id).await; player_handle.subscribe().await @@ -337,7 +354,7 @@ struct PlayerRegistryInner { struct Player { player_id: PlayerId, storage_id: u32, - connections: AnonTable>, + connections: AnonTable>, my_rooms: HashMap, banned_from: HashSet, rx: Receiver<(ActorCommand, Span)>, @@ -438,7 +455,7 @@ impl Player { _ => {} } for (_, connection) in &self.connections { - let _ = connection.send(update.clone()).await; + let _ = connection.send(ConnectionMessage::Update(update.clone())).await; } } @@ -589,7 +606,18 @@ impl Player { if ConnectionId(a) == except { continue; } - let _ = b.send(update.clone()).await; + let _ = b.send(ConnectionMessage::Update(update.clone())).await; } } } + +pub enum ConnectionMessage { + Update(Updates), + Stop(StopReason), +} + +#[derive(Debug)] +pub enum StopReason { + ServerShutdown, + InternalError, +} diff --git a/crates/mgmt-api/src/lib.rs b/crates/mgmt-api/src/lib.rs index cfe5b69..c21ff85 100644 --- a/crates/mgmt-api/src/lib.rs +++ b/crates/mgmt-api/src/lib.rs @@ -11,6 +11,11 @@ pub struct CreatePlayerRequest<'a> { pub name: &'a str, } +#[derive(Serialize, Deserialize)] +pub struct StopPlayerRequest<'a> { + pub name: &'a str, +} + #[derive(Serialize, Deserialize)] pub struct ChangePasswordRequest<'a> { pub player_name: &'a str, @@ -19,6 +24,7 @@ pub struct ChangePasswordRequest<'a> { pub mod paths { pub const CREATE_PLAYER: &'static str = "/mgmt/create_player"; + pub const STOP_PLAYER: &'static str = "/mgmt/stop_player"; pub const SET_PASSWORD: &'static str = "/mgmt/set_password"; } diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index a320c5d..2a310a1 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -507,11 +507,18 @@ async fn handle_registered_socket<'a>( buffer.clear(); }, update = connection.receiver.recv() => { - if let Some(update) = update { - handle_update(&config, &user, &player_id, writer, &rooms, update).await?; - } else { - log::warn!("Player is terminated, must terminate the connection"); - break; + match update { + Some(ConnectionMessage::Update(update)) => { + handle_update(&config, &user, &player_id, writer, &rooms, update).await?; + } + Some(ConnectionMessage::Stop(_)) => { + tracing::debug!("Connection is being terminated"); + break; + } + None => { + log::warn!("Player is terminated, must terminate the connection"); + break; + } } } } diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index fe56481..eec6fc3 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -23,7 +23,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey}; use tokio_rustls::TlsAcceptor; use lavina_core::auth::{Authenticator, Verdict}; -use lavina_core::player::{PlayerConnection, PlayerId, PlayerRegistry}; +use lavina_core::player::{ConnectionMessage, PlayerConnection, PlayerId, PlayerRegistry, StopReason}; use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::RoomRegistry; @@ -31,6 +31,7 @@ use lavina_core::terminator::Terminator; use lavina_core::LavinaCore; use proto_xmpp::bind::{Name, Resource}; use proto_xmpp::stream::*; +use proto_xmpp::streamerror::{StreamError, StreamErrorKind}; use proto_xmpp::xml::{Continuation, FromXml, Parser, ToXml}; use sasl::AuthBody; @@ -395,16 +396,41 @@ async fn socket_final( true }, update = conn.user_handle.receiver.recv() => { - if let Some(update) = update { - conn.handle_update(&mut events, update).await?; - for i in &events { - xml_writer.write_event_async(i).await?; + match update { + Some(ConnectionMessage::Update(update)) => { + conn.handle_update(&mut events, update).await?; + for i in &events { + xml_writer.write_event_async(i).await?; + } + events.clear(); + xml_writer.get_mut().flush().await?; + } + Some(ConnectionMessage::Stop(reason)) => { + tracing::debug!("Connection is being terminated: {reason:?}"); + let kind = match reason { + StopReason::ServerShutdown => StreamErrorKind::SystemShutdown, + StopReason::InternalError => StreamErrorKind::InternalServerError, + }; + StreamError { kind }.serialize(&mut events); + ServerStreamEnd.serialize(&mut events); + for i in &events { + xml_writer.write_event_async(i).await?; + } + events.clear(); + xml_writer.get_mut().flush().await?; + break; + } + None => { + log::error!("Player is terminated, must terminate the connection"); + StreamError { kind: StreamErrorKind::SystemShutdown }.serialize(&mut events); + ServerStreamEnd.serialize(&mut events); + for i in &events { + xml_writer.write_event_async(i).await?; + } + events.clear(); + xml_writer.get_mut().flush().await?; + break; } - events.clear(); - xml_writer.get_mut().flush().await?; - } else { - log::warn!("Player is terminated, must terminate the connection"); - break; } false } diff --git a/crates/proto-xmpp/src/lib.rs b/crates/proto-xmpp/src/lib.rs index 1e97a31..d3e25ba 100644 --- a/crates/proto-xmpp/src/lib.rs +++ b/crates/proto-xmpp/src/lib.rs @@ -10,6 +10,7 @@ pub mod sasl; pub mod session; pub mod stanzaerror; pub mod stream; +pub mod streamerror; pub mod tls; pub mod xml; diff --git a/crates/proto-xmpp/src/streamerror.rs b/crates/proto-xmpp/src/streamerror.rs new file mode 100644 index 0000000..0ba71cd --- /dev/null +++ b/crates/proto-xmpp/src/streamerror.rs @@ -0,0 +1,41 @@ +use crate::xml::ToXml; +use quick_xml::events::{BytesEnd, BytesStart, Event}; + +/// Stream error condition +/// +/// [Spec](https://xmpp.org/rfcs/rfc6120.html#streams-error-conditions). +pub enum StreamErrorKind { + /// The server has experienced a misconfiguration or other internal error that prevents it from servicing the stream. + InternalServerError, + /// The server is being shut down and all active streams are being closed. + SystemShutdown, +} +impl StreamErrorKind { + pub fn from_str(s: &str) -> Option { + match s { + "internal-server-error" => Some(Self::InternalServerError), + "system-shutdown" => Some(Self::SystemShutdown), + _ => None, + } + } + pub fn as_str(&self) -> &'static str { + match self { + Self::InternalServerError => "internal-server-error", + Self::SystemShutdown => "system-shutdown", + } + } +} + +pub struct StreamError { + pub kind: StreamErrorKind, +} +impl ToXml for StreamError { + fn serialize(&self, events: &mut Vec>) { + events.push(Event::Start(BytesStart::new("stream:error"))); + events.push(Event::Empty(BytesStart::new(format!( + r#"{} xmlns="urn:ietf:params:xml:ns:xmpp-streams""#, + self.kind.as_str() + )))); + events.push(Event::End(BytesEnd::new("stream:error"))); + } +} diff --git a/src/http.rs b/src/http.rs index b39a6f2..ae64676 100644 --- a/src/http.rs +++ b/src/http.rs @@ -13,6 +13,7 @@ use serde::{Deserialize, Serialize}; use tokio::net::TcpListener; use lavina_core::auth::{Authenticator, UpdatePasswordResult}; +use lavina_core::player::{PlayerId, PlayerRegistry}; use lavina_core::prelude::*; use lavina_core::repo::Storage; use lavina_core::room::RoomRegistry; @@ -85,8 +86,9 @@ async fn route( (&Method::GET, "/metrics") => endpoint_metrics(registry), (&Method::GET, "/rooms") => endpoint_rooms(core.rooms).await, (&Method::POST, paths::CREATE_PLAYER) => endpoint_create_player(request, storage).await.or5xx(), + (&Method::POST, paths::STOP_PLAYER) => endpoint_stop_player(request, core.players).await.or5xx(), (&Method::POST, paths::SET_PASSWORD) => endpoint_set_password(request, storage).await.or5xx(), - _ => not_found(), + _ => endpoint_not_found(), }; Ok(res) } @@ -98,6 +100,7 @@ fn endpoint_metrics(registry: MetricsRegistry) -> Response> { Response::new(Full::new(Bytes::from(buffer))) } +#[tracing::instrument(skip_all)] async fn endpoint_rooms(rooms: RoomRegistry) -> Response> { // TODO introduce management API types independent from core-domain types // TODO remove `Serialize` implementations from all core-domain types @@ -105,6 +108,7 @@ async fn endpoint_rooms(rooms: RoomRegistry) -> Response> { Response::new(room_list) } +#[tracing::instrument(skip_all)] async fn endpoint_create_player( request: Request, mut storage: Storage, @@ -120,6 +124,27 @@ async fn endpoint_create_player( Ok(response) } +#[tracing::instrument(skip_all)] +async fn endpoint_stop_player( + request: Request, + players: PlayerRegistry, +) -> Result>> { + let str = request.collect().await?.to_bytes(); + let Ok(res) = serde_json::from_slice::(&str[..]) else { + return Ok(malformed_request()); + }; + let Ok(player_id) = PlayerId::from(res.name) else { + return Ok(player_not_found()); + }; + let Some(()) = players.stop_player(&player_id).await? else { + return Ok(player_not_found()); + }; + let mut response = Response::new(Full::::default()); + *response.status_mut() = StatusCode::NO_CONTENT; + Ok(response) +} + +#[tracing::instrument(skip_all)] async fn endpoint_set_password( request: Request, storage: Storage, @@ -132,14 +157,7 @@ async fn endpoint_set_password( match verdict { UpdatePasswordResult::PasswordUpdated => {} UpdatePasswordResult::UserNotFound => { - let payload = ErrorResponse { - code: errors::PLAYER_NOT_FOUND, - message: "No such player exists", - } - .to_body(); - let mut response = Response::new(payload); - *response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY; - return Ok(response); + return Ok(player_not_found()); } } let mut response = Response::new(Full::::default()); @@ -147,7 +165,7 @@ async fn endpoint_set_password( Ok(response) } -pub fn not_found() -> Response> { +fn endpoint_not_found() -> Response> { let payload = ErrorResponse { code: errors::INVALID_PATH, message: "The path does not exist", @@ -159,6 +177,17 @@ pub fn not_found() -> Response> { response } +fn player_not_found() -> Response> { + let payload = ErrorResponse { + code: errors::PLAYER_NOT_FOUND, + message: "No such player exists", + } + .to_body(); + let mut response = Response::new(payload); + *response.status_mut() = StatusCode::UNPROCESSABLE_ENTITY; + response +} + fn malformed_request() -> Response> { let payload = ErrorResponse { code: errors::MALFORMED_REQUEST, @@ -174,6 +203,7 @@ fn malformed_request() -> Response> { trait Or5xx { fn or5xx(self) -> Response>; } + impl Or5xx for Result>> { fn or5xx(self) -> Response> { self.unwrap_or_else(|e| { @@ -187,6 +217,7 @@ impl Or5xx for Result>> { trait ToBody { fn to_body(&self) -> Full; } + impl ToBody for T where T: Serialize,