diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 4808b85..182d28b 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -54,7 +54,7 @@ async fn handle_socket( socket_addr: &SocketAddr, players: PlayerRegistry, rooms: RoomRegistry, - termination: Deferred<()>, // TODO use it to stop the connection gracefully + termination: Deferred<()>, mut storage: Storage, ) -> Result<()> { log::info!("Received an IRC connection from {socket_addr}"); @@ -62,19 +62,24 @@ async fn handle_socket( let mut reader: BufReader = BufReader::new(reader); let mut writer = BufWriter::new(writer); - let registered_user: Result = - handle_registration(&mut reader, &mut writer, &mut storage, &config).await; - - match registered_user { - Ok(user) => { - log::debug!("User registered"); - handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?; - } - Err(err) => { - log::debug!("Registration failed: {err}"); - } + pin!(termination); + select! { + biased; + _ = &mut termination =>{ + log::info!("Socket handling was terminated"); + return Ok(()) + }, + registered_user = handle_registration(&mut reader, &mut writer, &mut storage, &config) => + match registered_user { + Ok(user) => { + log::debug!("User registered"); + handle_registered_socket(config, players, rooms, &mut reader, &mut writer, user).await?; + } + Err(err) => { + log::debug!("Registration failed: {err}"); + } + } } - stream.shutdown().await?; Ok(()) } diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index a0ee071..5becdae 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -1,3 +1,4 @@ +use std::net::SocketAddr; use std::time::Duration; use anyhow::{anyhow, Result}; @@ -220,8 +221,9 @@ async fn scenario_cap_short_negotiation() -> Result<()> { } #[tokio::test] -async fn scenario_cap_sasl_fail() -> Result<()> { +async fn terminate_socket_scenario() -> Result<()> { let mut server = TestServer::start().await?; + let address: SocketAddr = ("127.0.0.1:0".parse().unwrap()); // test scenario @@ -231,38 +233,16 @@ async fn scenario_cap_sasl_fail() -> Result<()> { let mut stream = TcpStream::connect(server.server.addr).await?; let mut s = TestScope::new(&mut stream); - s.send("CAP LS 302").await?; s.send("NICK tester").await?; - s.send("USER UserName 0 * :Real Name").await?; - s.expect(":testserver CAP * LS :sasl=PLAIN").await?; s.send("CAP REQ :sasl").await?; + s.send("USER UserName 0 * :Real Name").await?; s.expect(":testserver CAP tester ACK :sasl").await?; - s.send("AUTHENTICATE SHA256").await?; - s.expect(":testserver 904 tester :Unsupported mechanism").await?; s.send("AUTHENTICATE PLAIN").await?; s.expect(":testserver AUTHENTICATE +").await?; - s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZDE=").await?; - s.expect(":testserver 904 tester :Bad credentials").await?; - s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' - s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; - s.expect(":testserver 903 tester :SASL authentication successful").await?; - - s.send("CAP END").await?; - - s.expect(":testserver 001 tester :Welcome to Kek Server").await?; - s.expect(":testserver 002 tester :Welcome to Kek Server").await?; - s.expect(":testserver 003 tester :Welcome to Kek Server").await?; - s.expect(":testserver 004 tester testserver kek-0.1.alpha.3 r CFILPQbcefgijklmnopqrstvz").await?; - s.expect(":testserver 005 tester CHANTYPES=# :are supported by this server").await?; - s.expect_nothing().await?; - s.send("QUIT :Leaving").await?; - s.expect(":testserver ERROR :Leaving the server").await?; - s.expect_eof().await?; stream.shutdown().await?; - - // wrap up - server.server.terminate().await?; + + assert!(TcpStream::connect(&address).await.is_err()); Ok(()) } diff --git a/crates/projection-xmpp/src/lib.rs b/crates/projection-xmpp/src/lib.rs index 9e852ca..84fe721 100644 --- a/crates/projection-xmpp/src/lib.rs +++ b/crates/projection-xmpp/src/lib.rs @@ -162,7 +162,7 @@ async fn handle_socket( mut players: PlayerRegistry, rooms: RoomRegistry, mut storage: Storage, - termination: Deferred<()>, // TODO use it to stop the connection gracefully + termination: Deferred<()>, ) -> Result<()> { log::info!("Received an XMPP connection from {socket_addr}"); let mut reader_buf = vec![]; @@ -187,18 +187,34 @@ async fn handle_socket( let mut xml_reader = NsReader::from_reader(BufReader::new(a)); let mut xml_writer = Writer::new(b); - let authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage).await?; - log::debug!("User authenticated"); - let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; - socket_final( - &mut xml_reader, - &mut xml_writer, - &mut reader_buf, - &authenticated, - &mut connection, - &rooms, - ) - .await?; + pin!(termination); + select! { + biased; + _ = &mut termination =>{ + log::info!("Socket handling was terminated"); + return Ok(()) + }, + authenticated = socket_auth(&mut xml_reader, &mut xml_writer, &mut reader_buf, &mut storage) => { + match authenticated { + Ok(authenticated) => { + let mut connection = players.connect_to_player(authenticated.player_id.clone()).await; + socket_final( + &mut xml_reader, + &mut xml_writer, + &mut reader_buf, + &authenticated, + &mut connection, + &rooms, + ) + .await?; + }, + Err(err) => { + log::error!("Authentication error: {:?}", err); + } + } + }, + } + let a = xml_reader.into_inner().into_inner(); let b = xml_writer.into_inner(); diff --git a/crates/projection-xmpp/tests/lib.rs b/crates/projection-xmpp/tests/lib.rs index 39675eb..9ce216b 100644 --- a/crates/projection-xmpp/tests/lib.rs +++ b/crates/projection-xmpp/tests/lib.rs @@ -1,3 +1,4 @@ +use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -184,3 +185,69 @@ async fn scenario_basic() -> Result<()> { server.terminate().await?; Ok(()) } + + +#[tokio::test] +async fn terminate_socket() -> Result<()> { + tracing_subscriber::fmt::init(); + let config = ServerConfig { + listen_on: "127.0.0.1:0".parse().unwrap(), + cert: "tests/certs/xmpp.pem".parse().unwrap(), + key: "tests/certs/xmpp.key".parse().unwrap(), + }; + let mut metrics = MetricsRegistry::new(); + let mut storage = Storage::open(StorageConfig { + db_path: ":memory:".into(), + }) + .await?; + let rooms = RoomRegistry::new(&mut metrics, storage.clone()).unwrap(); + let players = PlayerRegistry::empty(rooms.clone(), &mut metrics).unwrap(); + let server = launch(config, players, rooms, metrics, storage.clone()).await.unwrap(); + let address: SocketAddr = ("127.0.0.1:0".parse().unwrap()); + // test scenario + + storage.create_user("tester").await?; + storage.set_password("tester", "password").await?; + + let mut stream = TcpStream::connect(server.addr).await?; + let mut s = TestScope::new(&mut stream); + tracing::info!("TCP connection established"); + + s.send(r#""#).await?; + + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"features")); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"required")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"starttls")); + assert_matches!(s.next_xml_event().await?, Event::End(b) => assert_eq!(b.local_name().into_inner(), b"features")); + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Empty(b) => assert_eq!(b.local_name().into_inner(), b"proceed")); + let buffer = s.buffer; + + let connector = TlsConnector::from(Arc::new( + ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(IgnoreCertVerification)) + .with_no_client_auth(), + )); + tracing::info!("Initiating TLS connection..."); + let mut stream = connector.connect(ServerName::IpAddress(server.addr.ip()), stream).await?; + tracing::info!("TLS connection established"); + + let mut s = TestScopeTls::new(&mut stream, buffer); + + s.send(r#""#).await?; + s.send(r#""#).await?; + assert_matches!(s.next_xml_event().await?, Event::Decl(_) => {}); + assert_matches!(s.next_xml_event().await?, Event::Start(b) => assert_eq!(b.local_name().into_inner(), b"stream")); + + stream.shutdown().await?; + server.terminate().await?; + + assert!(TcpStream::connect(&address).await.is_err()); + + Ok(()) +}