From b6d0675976cb6d1e68bc1e3d27829aced9440eb0 Mon Sep 17 00:00:00 2001 From: Noratrieb <48135649+Noratrieb@users.noreply.github.com> Date: Sun, 25 Aug 2024 15:01:22 +0200 Subject: [PATCH] migrate cluelessh-faked to cluelessh-tokio --- Cargo.lock | 10 + bin/cluelessh-faked/Cargo.toml | 1 + bin/cluelessh-faked/src/main.rs | 180 +++++++-------- bin/cluelessh/src/main.rs | 4 +- bin/cluelesshd/Cargo.toml | 9 + bin/cluelesshd/src/main.rs | 3 + lib/cluelessh-connection/src/lib.rs | 12 +- lib/cluelessh-protocol/src/lib.rs | 45 +++- lib/cluelessh-tokio/src/client.rs | 27 +-- lib/cluelessh-tokio/src/lib.rs | 32 +++ lib/cluelessh-tokio/src/server.rs | 327 ++++++++++++++++++++++++++++ 11 files changed, 513 insertions(+), 137 deletions(-) create mode 100644 bin/cluelesshd/Cargo.toml create mode 100644 bin/cluelesshd/src/main.rs create mode 100644 lib/cluelessh-tokio/src/server.rs diff --git a/Cargo.lock b/Cargo.lock index 47f276e..a706997 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -347,6 +347,7 @@ name = "cluelessh-faked" version = "0.1.0" dependencies = [ "cluelessh-protocol", + "cluelessh-tokio", "eyre", "hex-literal", "rand", @@ -425,6 +426,15 @@ dependencies = [ "x25519-dalek", ] +[[package]] +name = "cluelesshd" +version = "0.1.0" +dependencies = [ + "cluelessh-protocol", + "cluelessh-tokio", + "cluelessh-transport", +] + [[package]] name = "colorchoice" version = "1.0.2" diff --git a/bin/cluelessh-faked/Cargo.toml b/bin/cluelessh-faked/Cargo.toml index 49dc8d2..5d5bf96 100644 --- a/bin/cluelessh-faked/Cargo.toml +++ b/bin/cluelessh-faked/Cargo.toml @@ -8,6 +8,7 @@ eyre = "0.6.12" hex-literal = "0.4.1" rand = "0.8.5" cluelessh-protocol = { path = "../../lib/cluelessh-protocol" } +cluelessh-tokio = { path = "../../lib/cluelessh-tokio" } tokio = { version = "1.39.2", features = ["full"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } diff --git a/bin/cluelessh-faked/src/main.rs b/bin/cluelessh-faked/src/main.rs index 1e00da1..a2eca62 100644 --- a/bin/cluelessh-faked/src/main.rs +++ b/bin/cluelessh-faked/src/main.rs @@ -1,27 +1,19 @@ -use std::{collections::HashMap, net::SocketAddr}; +use std::{net::SocketAddr, sync::Arc}; +use cluelessh_tokio::Channel; use eyre::{Context, Result}; -use rand::RngCore; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, + sync::Mutex, }; -use tracing::{debug, error, info, info_span, Instrument}; +use tracing::{debug, error, info, info_span, warn, Instrument}; use cluelessh_protocol::{ - connection::{ChannelOpen, ChannelOperationKind, ChannelRequest}, - transport::{self}, - ChannelUpdateKind, ServerConnection, SshStatus, + connection::{ChannelKind, ChannelOperationKind, ChannelRequest}, + ChannelUpdateKind, SshStatus, }; use tracing_subscriber::EnvFilter; -struct ThreadRngRand; -impl cluelessh_protocol::transport::SshRng for ThreadRngRand { - fn fill_bytes(&mut self, dest: &mut [u8]) { - rand::thread_rng().fill_bytes(dest); - } -} - #[tokio::main] async fn main() -> eyre::Result<()> { let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); @@ -45,14 +37,16 @@ async fn main() -> eyre::Result<()> { let listener = TcpListener::bind(addr).await.wrap_err("binding listener")?; + let mut listener = cluelessh_tokio::server::ServerListener::new(listener); + loop { let next = listener.accept().await?; - let span = info_span!("connection", addr = %next.1); + let span = info_span!("connection", addr = %next.peer_addr()); tokio::spawn( - async { - let mut total_sent_data = Vec::new(); + async move { + let total_sent_data = Arc::new(Mutex::new(Vec::new())); - if let Err(err) = handle_connection(next, &mut total_sent_data).await { + if let Err(err) = handle_connection(next, total_sent_data.clone()).await { if let Some(err) = err.downcast_ref::() { if err.kind() == std::io::ErrorKind::ConnectionReset { return; @@ -63,6 +57,7 @@ async fn main() -> eyre::Result<()> { } // Limit stdin to 500 characters. + let total_sent_data = total_sent_data.lock().await; let stdin = String::from_utf8_lossy(&total_sent_data); let stdin = if let Some((idx, _)) = stdin.char_indices().nth(500) { &stdin[..idx] @@ -78,46 +73,18 @@ async fn main() -> eyre::Result<()> { } async fn handle_connection( - next: (TcpStream, SocketAddr), - total_sent_data: &mut Vec, + mut conn: cluelessh_tokio::server::ServerConnection, + total_sent_data: Arc>>, ) -> Result<()> { - let (mut conn, addr) = next; - - info!(%addr, "Received a new connection"); - - /*let rng = vec![ - 0x14, 0xa2, 0x04, 0xa5, 0x4b, 0x2f, 0x5f, 0xa7, 0xff, 0x53, 0x13, 0x67, 0x57, 0x67, 0xbc, - 0x55, 0x3f, 0xc0, 0x6c, 0x0d, 0x07, 0x8f, 0xe2, 0x75, 0x95, 0x18, 0x4b, 0xd2, 0xcb, 0xd0, - 0x64, 0x06, 0x14, 0xa2, 0x04, 0xa5, 0x4b, 0x2f, 0x5f, 0xa7, 0xff, 0x53, 0x13, 0x67, 0x57, - 0x67, 0xbc, 0x55, 0x3f, 0xc0, 0x6c, 0x0d, 0x07, 0x8f, 0xe2, 0x75, 0x95, 0x18, 0x4b, 0xd2, - 0xcb, 0xd0, 0x64, 0x06, 0x67, 0xbc, 0x55, 0x3f, 0xc0, 0x6c, 0x0d, 0x07, 0x8f, 0xe2, 0x75, - 0x95, 0x18, 0x4b, 0xd2, 0xcb, 0xd0, 0x64, 0x06, - ]; - struct HardcodedRng(Vec); - impl cluelessh_protocol::transport::SshRng for HardcodedRng { - fn fill_bytes(&mut self, dest: &mut [u8]) { - dest.copy_from_slice(&self.0[..dest.len()]); - self.0.splice(0..dest.len(), []); - } - }*/ - - let mut state = ServerConnection::new(transport::server::ServerConnection::new(ThreadRngRand)); - - let mut session_channels = HashMap::new(); + info!(addr = %conn.peer_addr(), "Received a new connection"); loop { - let mut buf = [0; 1024]; - let read = conn - .read(&mut buf) - .await - .wrap_err("reading from connection")?; - if read == 0 { - info!("Did not read any bytes from TCP stream, EOF"); - return Ok(()); - } - - if let Err(err) = state.recv_bytes(&buf[..read]) { - match err { + match conn.progress().await { + Ok(()) => {} + Err(cluelessh_tokio::server::Error::ServerError(err)) => { + return Err(err); + } + Err(cluelessh_tokio::server::Error::SshStatus(status)) => match status { SshStatus::PeerError(err) => { info!(?err, "disconnecting client after invalid operation"); return Ok(()); @@ -126,28 +93,40 @@ async fn handle_connection( info!("Received disconnect from client"); return Ok(()); } - } + }, } - while let Some(update) = state.next_channel_update() { - //eprintln!("{:?}", update); - match update.kind { - ChannelUpdateKind::Open(kind) => match kind { - ChannelOpen::Session => { - session_channels.insert(update.number, ()); - } - }, + while let Some(channel) = conn.next_new_channel() { + if *channel.kind() == ChannelKind::Session { + let total_sent_data = total_sent_data.clone(); + tokio::spawn(async { + let _ = handle_session_channel(channel, total_sent_data).await; + }); + } else { + warn!("Trying to open non-session channel"); + } + } + } +} + +async fn handle_session_channel( + mut channel: Channel, + total_sent_data: Arc>>, +) -> Result<()> { + loop { + match channel.next_update().await { + Ok(update) => match update { ChannelUpdateKind::Request(req) => { - let success = update.number.construct_op(ChannelOperationKind::Success); + let success = ChannelOperationKind::Success; match req { ChannelRequest::PtyReq { want_reply, .. } => { if want_reply { - state.do_operation(success); + channel.send(success).await?; } } ChannelRequest::Shell { want_reply } => { if want_reply { - state.do_operation(success); + channel.send(success).await?; } } ChannelRequest::Exec { @@ -155,26 +134,20 @@ async fn handle_connection( command, } => { if want_reply { - state.do_operation(success); + channel.send(success).await?; } let result = execute_command(&command); - state.do_operation( - update - .number - .construct_op(ChannelOperationKind::Data(result.stdout)), - ); - state.do_operation(update.number.construct_op( - ChannelOperationKind::Request(ChannelRequest::ExitStatus { + channel + .send(ChannelOperationKind::Data(result.stdout)) + .await?; + channel + .send(ChannelOperationKind::Request(ChannelRequest::ExitStatus { status: result.status, - }), - )); - state.do_operation( - update.number.construct_op(ChannelOperationKind::Eof), - ); - state.do_operation( - update.number.construct_op(ChannelOperationKind::Close), - ); + })) + .await?; + channel.send(ChannelOperationKind::Eof).await?; + channel.send(ChannelOperationKind::Close).await?; } ChannelRequest::ExitStatus { .. } => {} ChannelRequest::Env { .. } => {} @@ -185,44 +158,37 @@ async fn handle_connection( let is_eof = data.contains(&0x04 /*EOF, Ctrl-D*/); // echo :3 - state.do_operation( - update - .number - .construct_op(ChannelOperationKind::Data(data.clone())), - ); + channel + .send(ChannelOperationKind::Data(data.clone())) + .await?; + let mut total_sent_data = total_sent_data.lock().await; // arbitrary limit if total_sent_data.len() < 50_000 { total_sent_data.extend_from_slice(&data); } else { - info!(channel = %update.number, "Reached stdin limit"); - state.do_operation(update.number.construct_op(ChannelOperationKind::Data( - b"Thanks Hayley!\n".to_vec(), - ))); - state.do_operation(update.number.construct_op(ChannelOperationKind::Close)); + info!("Reached stdin limit"); + channel + .send(ChannelOperationKind::Data(b"Thanks Hayley!\n".to_vec())) + .await?; + channel.send(ChannelOperationKind::Close).await?; } if is_eof { - debug!(channel = %update.number, "Received Ctrl-D, closing channel"); + debug!("Received Ctrl-D, closing channel"); - state.do_operation(update.number.construct_op(ChannelOperationKind::Eof)); - state.do_operation(update.number.construct_op(ChannelOperationKind::Close)); + channel.send(ChannelOperationKind::Eof).await?; + channel.send(ChannelOperationKind::Close).await?; } } - ChannelUpdateKind::ExtendedData { .. } + ChannelUpdateKind::Open(_) + | ChannelUpdateKind::Closed + | ChannelUpdateKind::ExtendedData { .. } | ChannelUpdateKind::Eof | ChannelUpdateKind::Success | ChannelUpdateKind::Failure => { /* ignore */ } - ChannelUpdateKind::Closed => { - session_channels.remove(&update.number); - } - } - } - - while let Some(msg) = state.next_msg_to_send() { - conn.write_all(&msg.to_bytes()) - .await - .wrap_err("writing response")?; + }, + Err(err) => return Err(err), } } } diff --git a/bin/cluelessh/src/main.rs b/bin/cluelessh/src/main.rs index 4c68526..f88e9c9 100644 --- a/bin/cluelessh/src/main.rs +++ b/bin/cluelessh/src/main.rs @@ -8,7 +8,7 @@ use cluelessh_transport::{key::PublicKey, numbers, parse::Writer}; use tokio::net::TcpStream; use tracing::{debug, error}; -use cluelessh_protocol::connection::{ChannelOpen, ChannelOperationKind, ChannelRequest}; +use cluelessh_protocol::connection::{ChannelKind, ChannelOperationKind, ChannelRequest}; use tracing_subscriber::EnvFilter; #[derive(clap::Parser, Debug)] @@ -121,7 +121,7 @@ async fn main() -> eyre::Result<()> { ) .await?; - let session = tokio_conn.open_channel(ChannelOpen::Session); + let session = tokio_conn.open_channel(ChannelKind::Session); tokio::spawn(async { let result = main_channel(session).await; diff --git a/bin/cluelesshd/Cargo.toml b/bin/cluelesshd/Cargo.toml new file mode 100644 index 0000000..6c15fba --- /dev/null +++ b/bin/cluelesshd/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "cluelesshd" +version = "0.1.0" +edition = "2021" + +[dependencies] +cluelessh-protocol = { path = "../../lib/cluelessh-protocol" } +cluelessh-tokio = { path = "../../lib/cluelessh-tokio" } +cluelessh-transport = { path = "../../lib/cluelessh-transport" } diff --git a/bin/cluelesshd/src/main.rs b/bin/cluelesshd/src/main.rs new file mode 100644 index 0000000..e7a11a9 --- /dev/null +++ b/bin/cluelesshd/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} diff --git a/lib/cluelessh-connection/src/lib.rs b/lib/cluelessh-connection/src/lib.rs index a11f593..8b5ca91 100644 --- a/lib/cluelessh-connection/src/lib.rs +++ b/lib/cluelessh-connection/src/lib.rs @@ -32,7 +32,7 @@ enum ChannelState { our_window_size: u32, /// For validation only. our_max_packet_size: u32, - update_message: ChannelOpen, + update_message: ChannelKind, }, Open(Channel), } @@ -71,7 +71,7 @@ pub struct ChannelUpdate { pub enum ChannelUpdateKind { Success, Failure, - Open(ChannelOpen), + Open(ChannelKind), OpenFailed { code: u32, message: String }, Request(ChannelRequest), Data { data: Vec }, @@ -80,7 +80,7 @@ pub enum ChannelUpdateKind { Closed, } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum ChannelOpen { +pub enum ChannelKind { Session, } #[derive(Debug)] @@ -173,7 +173,7 @@ impl ChannelsState { debug!(%channel_type, %sender_channel, "Receving channel open"); let update_message = match channel_type { - "session" => ChannelOpen::Session, + "session" => ChannelKind::Session, _ => { self.packets_to_send .push_back(Packet::new_msg_channel_open_failure( @@ -512,7 +512,7 @@ impl ChannelsState { } /// Create a new channel - pub fn create_channel(&mut self, kind: ChannelOpen) -> ChannelNumber { + pub fn create_channel(&mut self, kind: ChannelKind) -> ChannelNumber { let our_number = self.next_channel_id; self.next_channel_id = ChannelNumber( self.next_channel_id @@ -521,7 +521,7 @@ impl ChannelsState { .expect("created too many channels"), ); - assert_eq!(kind, ChannelOpen::Session, "TODO"); + assert_eq!(kind, ChannelKind::Session, "TODO"); let our_window_size = 2097152; // same as OpenSSH let our_max_packet_size = 32768; // same as OpenSSH diff --git a/lib/cluelessh-protocol/src/lib.rs b/lib/cluelessh-protocol/src/lib.rs index b63c858..e15bddd 100644 --- a/lib/cluelessh-protocol/src/lib.rs +++ b/lib/cluelessh-protocol/src/lib.rs @@ -44,8 +44,9 @@ impl ServerConnection { self.transport.send_plaintext_packet(to_send); } if auth.is_authenticated() { - self.state = - ServerConnectionState::Open(cluelessh_connection::ChannelsState::new(true)); + self.state = ServerConnectionState::Open( + cluelessh_connection::ChannelsState::new(true), + ); } } ServerConnectionState::Open(con) => { @@ -94,6 +95,20 @@ impl ServerConnection { } } } + + pub fn channels(&mut self) -> Option<&mut cluelessh_connection::ChannelsState> { + match &mut self.state { + ServerConnectionState::Open(channels) => Some(channels), + _ => None, + } + } + + pub fn auth(&mut self) -> Option<&mut auth::BadAuth> { + match &mut self.state { + ServerConnectionState::Auth(auth) => Some(auth), + _ => None, + } + } } pub struct ClientConnection { @@ -108,7 +123,10 @@ enum ClientConnectionState { } impl ClientConnection { - pub fn new(transport: cluelessh_transport::client::ClientConnection, auth: auth::ClientAuth) -> Self { + pub fn new( + transport: cluelessh_transport::client::ClientConnection, + auth: auth::ClientAuth, + ) -> Self { Self { transport, state: ClientConnectionState::Setup(Some(auth)), @@ -139,8 +157,9 @@ impl ClientConnection { self.transport.send_plaintext_packet(to_send); } if auth.is_authenticated() { - self.state = - ClientConnectionState::Open(cluelessh_connection::ChannelsState::new(false)); + self.state = ClientConnectionState::Open( + cluelessh_connection::ChannelsState::new(false), + ); } } ClientConnectionState::Open(con) => { @@ -227,6 +246,18 @@ pub mod auth { is_authenticated: bool, } + pub enum ServerRequest { + VerifyPassword { + user: String, + password: String, + }, + VerifyPubkey { + session_identifier: [u8; 32], + user: String, + pubkey: Vec, + }, + } + impl BadAuth { pub fn new() -> Self { Self { @@ -320,6 +351,10 @@ pub mod auth { self.is_authenticated } + pub fn server_requests(&mut self) -> impl Iterator + '_ { + [].into_iter() + } + fn queue_packet(&mut self, packet: Packet) { self.packets_to_send.push_back(packet); } diff --git a/lib/cluelessh-tokio/src/client.rs b/lib/cluelessh-tokio/src/client.rs index 88484f5..22d30ba 100644 --- a/lib/cluelessh-tokio/src/client.rs +++ b/lib/cluelessh-tokio/src/client.rs @@ -1,13 +1,15 @@ -use cluelessh_connection::{ChannelNumber, ChannelOpen, ChannelOperation, ChannelOperationKind}; +use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation, ChannelOperationKind}; use std::{collections::HashMap, pin::Pin, sync::Arc}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use cluelessh_protocol::{ChannelUpdateKind, SshStatus}; use eyre::{bail, ContextCompat, OptionExt, Result, WrapErr}; use futures::future::BoxFuture; -use cluelessh_protocol::{ChannelUpdateKind, SshStatus}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, warn}; +use crate::Channel; + pub struct ClientConnection { stream: Pin>, buf: [u8; 1024], @@ -55,11 +57,6 @@ pub struct PendingChannel { ready_recv: tokio::sync::oneshot::Receiver>, channel: Channel, } -pub struct Channel { - number: ChannelNumber, - updates_recv: tokio::sync::mpsc::Receiver, - ops_send: tokio::sync::mpsc::Sender, -} impl ClientConnection { pub async fn connect(stream: S, auth: ClientAuth) -> Result { @@ -75,7 +72,9 @@ impl ClientConnection { channel_ops_recv, channels: HashMap::new(), proto: cluelessh_protocol::ClientConnection::new( - cluelessh_transport::client::ClientConnection::new(cluelessh_protocol::ThreadRngRand), + cluelessh_transport::client::ClientConnection::new( + cluelessh_protocol::ThreadRngRand, + ), cluelessh_protocol::auth::ClientAuth::new(auth.username.as_bytes().to_vec()), ), auth, @@ -245,14 +244,14 @@ impl ClientConnection { Ok(()) } - pub fn open_channel(&mut self, kind: ChannelOpen) -> PendingChannel { + pub fn open_channel(&mut self, kind: ChannelKind) -> PendingChannel { let Some(channels) = self.proto.channels() else { panic!("connection not ready yet") }; let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10); let (ready_send, ready_recv) = tokio::sync::oneshot::channel(); - let number = channels.create_channel(kind); + let number = channels.create_channel(kind.clone()); self.channels.insert( number, @@ -268,6 +267,7 @@ impl ClientConnection { number, updates_recv, ops_send: self.channel_ops_send.clone(), + kind, }, } } @@ -290,11 +290,4 @@ impl Channel { .await .map_err(Into::into) } - - pub async fn next_update(&mut self) -> Result { - self.updates_recv - .recv() - .await - .ok_or_eyre("channel has been closed") - } } diff --git a/lib/cluelessh-tokio/src/lib.rs b/lib/cluelessh-tokio/src/lib.rs index b9babe5..63a1351 100644 --- a/lib/cluelessh-tokio/src/lib.rs +++ b/lib/cluelessh-tokio/src/lib.rs @@ -1 +1,33 @@ pub mod client; +pub mod server; + +use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation, ChannelOperationKind}; +use cluelessh_protocol::ChannelUpdateKind; +use eyre::{OptionExt, Result}; + +pub struct Channel { + number: ChannelNumber, + updates_recv: tokio::sync::mpsc::Receiver, + ops_send: tokio::sync::mpsc::Sender, + kind: ChannelKind, +} + +impl Channel { + pub async fn send(&mut self, op: ChannelOperationKind) -> Result<()> { + self.ops_send + .send(self.number.construct_op(op)) + .await + .map_err(Into::into) + } + + pub async fn next_update(&mut self) -> Result { + self.updates_recv + .recv() + .await + .ok_or_eyre("channel has been closed") + } + + pub fn kind(&self) -> &ChannelKind { + &self.kind + } +} diff --git a/lib/cluelessh-tokio/src/server.rs b/lib/cluelessh-tokio/src/server.rs new file mode 100644 index 0000000..c5e68f8 --- /dev/null +++ b/lib/cluelessh-tokio/src/server.rs @@ -0,0 +1,327 @@ +use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation}; +use std::{ + collections::{HashMap, VecDeque}, + net::SocketAddr, + pin::Pin, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, +}; + +use cluelessh_protocol::{ChannelUpdateKind, SshStatus}; +use eyre::{eyre, ContextCompat, Result, WrapErr}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::info; + +use crate::Channel; + +pub struct ServerListener { + listener: TcpListener, + // todo ratelimits etc +} + +pub struct ServerConnection { + stream: Pin>, + peer_addr: SocketAddr, + buf: [u8; 1024], + + proto: cluelessh_protocol::ServerConnection, + operations_send: tokio::sync::mpsc::Sender, + operations_recv: tokio::sync::mpsc::Receiver, + + /// Cloned and passed on to channels. + channel_ops_send: tokio::sync::mpsc::Sender, + channel_ops_recv: tokio::sync::mpsc::Receiver, + + channels: HashMap, + + /// New channels opened by the peer. + new_channels: VecDeque, +} + +enum ChannelState { + Pending { + ready_send: tokio::sync::oneshot::Sender>, + updates_send: tokio::sync::mpsc::Sender, + }, + Ready(tokio::sync::mpsc::Sender), +} + +enum Operation { + VerifyPassword { + user: String, + password: String, + }, + VerifyPubkey { + session_identifier: [u8; 32], + user: String, + pubkey: Vec, + }, +} + +pub struct SignatureResult { + pub key_alg_name: &'static str, + pub public_key: Vec, + pub signature: Vec, +} + +pub struct PendingChannel { + ready_recv: tokio::sync::oneshot::Receiver>, + channel: Channel, +} +pub enum Error { + SshStatus(SshStatus), + ServerError(eyre::Report), +} +impl From for Error { + fn from(value: eyre::Report) -> Self { + Self::ServerError(value) + } +} + +impl ServerListener { + pub fn new(listener: TcpListener) -> Self { + Self { listener } + } + + pub async fn accept(&mut self) -> Result> { + let (conn, peer_addr) = self.listener.accept().await?; + + Ok(ServerConnection::new(conn, peer_addr)) + } +} + +impl ServerConnection { + pub fn new(stream: S, peer_addr: SocketAddr) -> Self { + let (operations_send, operations_recv) = tokio::sync::mpsc::channel(15); + let (channel_ops_send, channel_ops_recv) = tokio::sync::mpsc::channel(15); + + Self { + stream: Box::pin(stream), + peer_addr, + buf: [0; 1024], + operations_send, + operations_recv, + channel_ops_send, + channel_ops_recv, + channels: HashMap::new(), + proto: cluelessh_protocol::ServerConnection::new( + cluelessh_transport::server::ServerConnection::new( + cluelessh_protocol::ThreadRngRand, + ), + ), + new_channels: VecDeque::new(), + } + } + + pub fn peer_addr(&self) -> SocketAddr { + self.peer_addr + } + + /// Executes one loop iteration of the main loop. + // IMPORTANT: no operations on this struct should ever block the main loop, except this one. + pub async fn progress(&mut self) -> Result<(), Error> { + if let Some(auth) = self.proto.auth() { + for req in auth.server_requests() { + match req { + cluelessh_protocol::auth::ServerRequest::VerifyPassword { user, password } => { + let send = self.operations_send.clone(); + tokio::spawn(async move { + let _ = send + .send(Operation::VerifyPassword { user, password }) + .await; + }); + } + cluelessh_protocol::auth::ServerRequest::VerifyPubkey { + session_identifier, + pubkey, + user, + } => { + let send = self.operations_send.clone(); + tokio::spawn(async move { + let _ = send + .send(Operation::VerifyPubkey { + session_identifier, + user, + pubkey, + }) + .await; + }); + } + } + } + } + + if let Some(channels) = self.proto.channels() { + while let Some(update) = channels.next_channel_update() { + match &update.kind { + ChannelUpdateKind::Open(channel_kind) => { + let channel = self.channels.get_mut(&update.number); + + match channel { + // We opened. + Some(ChannelState::Pending { updates_send, .. }) => { + let updates_send = updates_send.clone(); + let old = self + .channels + .insert(update.number, ChannelState::Ready(updates_send)); + match old.unwrap() { + ChannelState::Pending { ready_send, .. } => { + let _ = ready_send.send(Ok(())); + } + _ => unreachable!(), + } + } + Some(ChannelState::Ready(_)) => { + return Err(Error::ServerError(eyre!( + "attemping to open channel twice: {}", + update.number + ))) + } + // They opened. + None => { + let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10); + + let number = update.number; + + self.channels + .insert(number, ChannelState::Ready(updates_send)); + + let channel = Channel { + number, + updates_recv, + ops_send: self.channel_ops_send.clone(), + kind: channel_kind.clone(), + }; + self.new_channels.push_back(channel); + } + } + } + ChannelUpdateKind::OpenFailed { message, .. } => { + let channel = self + .channels + .get_mut(&update.number) + .wrap_err("unknown channel")?; + match channel { + ChannelState::Pending { .. } => { + let old = self.channels.remove(&update.number); + match old.unwrap() { + ChannelState::Pending { ready_send, .. } => { + let _ = ready_send.send(Err(message.clone())); + } + _ => unreachable!(), + } + } + ChannelState::Ready(_) => { + return Err(Error::ServerError(eyre!( + "attemping to open channel twice: {}", + update.number + ))) + } + } + } + _ => { + let channel = self + .channels + .get_mut(&update.number) + .wrap_err("unknown channel")?; + match channel { + ChannelState::Pending { .. } => { + return Err(Error::ServerError(eyre!("channel not ready yet"))) + } + ChannelState::Ready(updates_send) => { + let _ = updates_send.send(update.kind).await; + } + } + } + } + } + } + + // Make sure that we send all queued messages before going into the select, waiting for things to happen. + self.send_off_data().await?; + + tokio::select! { + read = self.stream.read(&mut self.buf) => { + let read = read.wrap_err("reading from connection")?; + if read == 0 { + info!("Did not read any bytes from TCP stream, EOF"); + return Ok(()); + } + if let Err(err) = self.proto.recv_bytes(&self.buf[..read]) { + return Err(Error::SshStatus(err)); + } + } + channel_op = self.channel_ops_recv.recv() => { + let channels = self.proto.channels().expect("connection not ready"); + if let Some(channel_op) = channel_op { + channels.do_operation(channel_op); + } + } + op = self.operations_recv.recv() => { + match op { + Some(Operation::VerifyPubkey { .. }) => todo!(), + Some(Operation::VerifyPassword { .. }) => todo!(), + None => {} + } + self.send_off_data().await?; + } + } + + Ok(()) + } + + async fn send_off_data(&mut self) -> Result<()> { + self.proto.progress(); + while let Some(msg) = self.proto.next_msg_to_send() { + self.stream + .write_all(&msg.to_bytes()) + .await + .wrap_err("writing response")?; + } + Ok(()) + } + + pub fn open_channel(&mut self, kind: ChannelKind) -> PendingChannel { + let Some(channels) = self.proto.channels() else { + panic!("connection not ready yet") + }; + let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10); + let (ready_send, ready_recv) = tokio::sync::oneshot::channel(); + + let number = channels.create_channel(kind.clone()); + + self.channels.insert( + number, + ChannelState::Pending { + ready_send, + updates_send, + }, + ); + + PendingChannel { + ready_recv, + channel: Channel { + number, + updates_recv, + ops_send: self.channel_ops_send.clone(), + kind, + }, + } + } + + pub fn next_new_channel(&mut self) -> Option { + self.new_channels.pop_front() + } +} + +impl PendingChannel { + pub async fn wait_ready(self) -> Result> { + match self.ready_recv.await { + Ok(Ok(())) => Ok(self.channel), + Ok(Err(err)) => Err(Some(err)), + Err(_) => Err(None), + } + } +}