use futures_util::TryStreamExt; use http_body_util::Empty; use hyper::header::{ CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE, }; use hyper::http::HeaderValue; use hyper::upgrade::Upgraded; use hyper::{body::Bytes, Request, Response}; use hyper::{Method, StatusCode, Version}; use std::convert::Infallible; use tokio_tungstenite::tungstenite::handshake::derive_accept_key; use tokio_tungstenite::tungstenite::protocol::Role; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::WebSocketStream; use futures_util::sink::SinkExt; use futures_util::stream::StreamExt; async fn handle_connection(ws_stream: WebSocketStream) { tracing::info!("WebSocket connection established"); let (mut outgoing, incoming) = ws_stream.split(); let broadcast_incoming = incoming.try_for_each(|msg| { tracing::info!("Received a message: {}", msg.to_text().unwrap()); async { Ok(()) } }); outgoing.send(Message::Text("adsads".into())).await.unwrap(); match broadcast_incoming.await { Ok(_) => tracing::info!("Disconnected"), Err(e) => tracing::warn!("Socket failed: {}", e), } } pub async fn handle_request( mut req: Request, ) -> std::result::Result>, Infallible> { dbg!(&req); println!("Received a new, potentially ws handshake"); println!("The request's path is: {}", req.uri().path()); println!("The request's headers are:"); for (ref header, _value) in req.headers() { println!("* {}", header); } let upgrade = HeaderValue::from_static("Upgrade"); let websocket = HeaderValue::from_static("websocket"); let headers = req.headers(); let key = headers.get(SEC_WEBSOCKET_KEY); let derived = key.map(|k| derive_accept_key(k.as_bytes())); if req.method() != Method::GET || req.version() < Version::HTTP_11 || !headers .get(CONNECTION) .and_then(|h| h.to_str().ok()) .map(|h| { h.split(|c| c == ' ' || c == ',') .any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap())) }) .unwrap_or(false) || !headers .get(UPGRADE) .and_then(|h| h.to_str().ok()) .map(|h| h.eq_ignore_ascii_case("websocket")) .unwrap_or(false) || !headers .get(SEC_WEBSOCKET_VERSION) .map(|h| h == "13") .unwrap_or(false) || key.is_none() || req.uri() != "/socket" { dbg!(); let mut resp = Response::new(Empty::new()); *resp.status_mut() = StatusCode::BAD_REQUEST; return Ok(resp); } let ver = req.version(); tokio::task::spawn(async move { match hyper::upgrade::on(&mut req).await { Ok(upgraded) => { handle_connection( WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await, ) .await; } Err(e) => println!("upgrade error: {}", e), } }); let mut res = Response::new(Empty::new()); *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; *res.version_mut() = ver; res.headers_mut().append(CONNECTION, upgrade); res.headers_mut().append(UPGRADE, websocket); res.headers_mut() .append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap()); Ok(res) }