diff --git a/crates/projection-irc/src/lib.rs b/crates/projection-irc/src/lib.rs index 82c3d29..ff3d892 100644 --- a/crates/projection-irc/src/lib.rs +++ b/crates/projection-irc/src/lib.rs @@ -181,19 +181,15 @@ async fn handle_registration<'a>( } CapabilitySubcommand::End => { let Some((ref username, ref realname)) = future_username else { - sasl_fail_message(config.server_name.clone()).write_async(writer).await?; - writer.flush().await?; - continue; + break Err(anyhow::Error::msg("Protocol violated")); }; let Some(nickname) = future_nickname.clone() else { - sasl_fail_message(config.server_name.clone()).write_async(writer).await?; - writer.flush().await?; - continue; + break Err(anyhow::Error::msg("Protocol violated")); }; let username = username.clone(); let realname = realname.clone(); let candidate_user = RegisteredUser { - nickname, + nickname: nickname.clone(), username, realname }; @@ -203,7 +199,11 @@ async fn handle_registration<'a>( break Ok(candidate_user); } else { let Some(candidate_password) = pass else { - sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + sasl_fail_message( + config.server_name.clone(), + nickname.clone(), + "User credentials was not provided".into() + ).write_async(writer).await?; writer.flush().await?; continue; }; @@ -217,12 +217,16 @@ async fn handle_registration<'a>( future_nickname = Some(nickname); } else if let Some((username, realname)) = future_username.clone() { let candidate_user = RegisteredUser { - nickname, + nickname: nickname.clone(), username, realname, }; let Some(candidate_password) = pass else { - sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + sasl_fail_message( + config.server_name.clone(), + nickname.clone(), + "User credentials was not provided".into() + ).write_async(writer).await?; writer.flush().await?; continue; }; @@ -237,12 +241,16 @@ async fn handle_registration<'a>( future_username = Some((username, realname)); } else if let Some(nickname) = future_nickname.clone() { let candidate_user = RegisteredUser { - nickname, + nickname: nickname.clone(), username, realname, }; let Some(candidate_password) = pass else { - sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + sasl_fail_message( + config.server_name.clone(), + nickname.clone(), + "User credentials was not provided".into() + ).write_async(writer).await?; writer.flush().await?; continue; }; @@ -267,46 +275,52 @@ async fn handle_registration<'a>( .await?; writer.flush().await?; } else { - sasl_fail_message(config.server_name.clone()).write_async(writer).await?; - writer.flush().await?; + if let Some(nickname) = future_nickname.clone() { + sasl_fail_message(config.server_name.clone(), nickname.clone(), "Unsupported mechanism".into()).write_async(writer).await?; + writer.flush().await?; + } else { + break Err(anyhow::Error::msg("Wrong authentication sequence")); + } } } else { let body = AuthBody::from_str(body.as_bytes())?; - match auth_user(storage, &body.login, &body.password).await { - Err(e) => { - tracing::warn!("Authentication failed: {:?}", e); - sasl_fail_message(config.server_name.clone()).write_async(writer).await?; + if let Err(e) = auth_user(storage, &body.login, &body.password).await { + tracing::warn!("Authentication failed: {:?}", e); + if let Some(nickname) = future_nickname.clone() { + sasl_fail_message(config.server_name.clone(), nickname.clone(), "Bad credentials".into()).write_async(writer).await?; writer.flush().await?; + } else { + } - Ok(_) => { - let login: Str = body.login.into(); - validated_user = Some(login.clone()); - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N900LoggedIn { - nick: login.clone(), - address: login.clone(), - account: login.clone(), - message: format!("You are now logged in as {}", login).into(), - }, - } - .write_async(writer) - .await?; - ServerMessage { - tags: vec![], - sender: Some(config.server_name.clone().into()), - body: ServerMessageBody::N903SaslSuccess { - nick: login.clone(), - message: "SASL authentication successful".into(), - }, - } - .write_async(writer) - .await?; - writer.flush().await?; + } else { + let login: Str = body.login.into(); + validated_user = Some(login.clone()); + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N900LoggedIn { + nick: login.clone(), + address: login.clone(), + account: login.clone(), + message: format!("You are now logged in as {}", login).into(), + }, } + .write_async(writer) + .await?; + ServerMessage { + tags: vec![], + sender: Some(config.server_name.clone().into()), + body: ServerMessageBody::N903SaslSuccess { + nick: login.clone(), + message: "SASL authentication successful".into(), + }, + } + .write_async(writer) + .await?; + writer.flush().await?; } } + // TODO handle abortion of authentication } _ => {} @@ -317,11 +331,11 @@ async fn handle_registration<'a>( Ok(user) } -fn sasl_fail_message(sender: Str) -> ServerMessage { +fn sasl_fail_message(sender: Str, nick: Str, text: Str) -> ServerMessage { ServerMessage { tags: vec![], sender: Some(sender), - body: ServerMessageBody::N904SaslFail + body: ServerMessageBody::N904SaslFail { nick, text } } } diff --git a/crates/projection-irc/tests/lib.rs b/crates/projection-irc/tests/lib.rs index 7ddf427..1a686a4 100644 --- a/crates/projection-irc/tests/lib.rs +++ b/crates/projection-irc/tests/lib.rs @@ -241,6 +241,8 @@ async fn scenario_cap_sasl_fail() -> Result<()> { s.expect(":testserver 904").await?; s.send("AUTHENTICATE PLAIN").await?; s.expect(":testserver AUTHENTICATE +").await?; + s.send("AUTHENTICATE wrong_password").await?; + s.expect(":testserver 904").await?; s.send("AUTHENTICATE dGVzdGVyAHRlc3RlcgBwYXNzd29yZA==").await?; // base64-encoded 'tester\x00tester\x00password' s.expect(":testserver 900 tester tester tester :You are now logged in as tester").await?; s.expect(":testserver 903 tester :SASL authentication successful").await?; diff --git a/crates/proto-irc/src/server.rs b/crates/proto-irc/src/server.rs index aff864d..4c1836c 100644 --- a/crates/proto-irc/src/server.rs +++ b/crates/proto-irc/src/server.rs @@ -161,7 +161,10 @@ pub enum ServerMessageBody { nick: Str, message: Str, }, - N904SaslFail + N904SaslFail { + nick: Str, + text: Str, + } } impl ServerMessageBody { @@ -393,8 +396,12 @@ impl ServerMessageBody { writer.write_all(b" :").await?; writer.write_all(message.as_bytes()).await?; } - ServerMessageBody::N904SaslFail => { + ServerMessageBody::N904SaslFail { nick, text } => { writer.write_all(b"904").await?; + writer.write_all(b" ").await?; + writer.write_all(nick.as_bytes()).await?; + writer.write_all(b" :").await?; + writer.write_all(text.as_bytes()).await?; } } Ok(())