// lavina/src/http/ws.rs

use futures_util::TryStreamExt;
use http_body_util::Empty;
use hyper::header::{
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
};
use hyper::http::HeaderValue;
use hyper::upgrade::Upgraded;
use hyper::{body::Bytes, Request, Response};
use hyper::{Method, StatusCode, Version};
use std::convert::Infallible;
use tokio_tungstenite::tungstenite::handshake::derive_accept_key;
use tokio_tungstenite::tungstenite::protocol::Role;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::WebSocketStream;
use futures_util::sink::SinkExt;
use futures_util::stream::StreamExt;
2023-01-27 20:43:20 +00:00
async fn handle_connection(ws_stream: WebSocketStream<Upgraded>) {
2023-01-25 12:50:14 +00:00
tracing::info!("WebSocket connection established");
let (mut outgoing, incoming) = ws_stream.split();
let broadcast_incoming = incoming.try_for_each(|msg| {
2023-01-27 20:43:20 +00:00
tracing::info!("Received a message: {}", msg.to_text().unwrap());
2023-01-25 12:50:14 +00:00
async { Ok(()) }
});
outgoing.send(Message::Text("adsads".into())).await.unwrap();
match broadcast_incoming.await {
Ok(_) => tracing::info!("Disconnected"),
Err(e) => tracing::warn!("Socket failed: {}", e),
}
}
pub async fn handle_request(
mut req: Request<hyper::body::Incoming>,
) -> std::result::Result<Response<Empty<Bytes>>, Infallible> {
dbg!(&req);
println!("Received a new, potentially ws handshake");
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
for (ref header, _value) in req.headers() {
println!("* {}", header);
}
let upgrade = HeaderValue::from_static("Upgrade");
let websocket = HeaderValue::from_static("websocket");
let headers = req.headers();
let key = headers.get(SEC_WEBSOCKET_KEY);
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
if req.method() != Method::GET
|| req.version() < Version::HTTP_11
|| !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap()))
})
.unwrap_or(false)
|| !headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
|| !headers
.get(SEC_WEBSOCKET_VERSION)
.map(|h| h == "13")
.unwrap_or(false)
|| key.is_none()
|| req.uri() != "/socket"
{
dbg!();
let mut resp = Response::new(Empty::new());
*resp.status_mut() = StatusCode::BAD_REQUEST;
return Ok(resp);
}
let ver = req.version();
tokio::task::spawn(async move {
match hyper::upgrade::on(&mut req).await {
Ok(upgraded) => {
handle_connection(
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await,
)
.await;
}
Err(e) => println!("upgrade error: {}", e),
}
});
let mut res = Response::new(Empty::new());
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
*res.version_mut() = ver;
res.headers_mut().append(CONNECTION, upgrade);
res.headers_mut().append(UPGRADE, websocket);
res.headers_mut()
.append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap());
Ok(res)
}